diff --git a/.github/workflows/_e2e_test.yaml b/.github/workflows/_e2e_test.yaml index 32775b53..42007d8d 100644 --- a/.github/workflows/_e2e_test.yaml +++ b/.github/workflows/_e2e_test.yaml @@ -107,7 +107,6 @@ jobs: # ------------------------------------ v1 spec decode test ------------------------------------ # pytest -sv tests/e2e/singlecard/spec_decode_v1/test_v1_mtp_correctness.py - pytest -sv tests/e2e/singlecard/spec_decode_v1/test_v1_mtp_torchair_correctness.py pytest -sv tests/e2e/singlecard/spec_decode_v1/test_v1_spec_decode.py e2e-2-cards: @@ -170,10 +169,6 @@ jobs: if: ${{ inputs.type == 'light' }} run: | pytest -sv tests/e2e/multicard/test_qwen3_moe.py::test_models_distributed_Qwen3_MOE_TP2_WITH_EP - pytest -sv tests/e2e/multicard/test_torchair_graph_mode.py::test_e2e_qwen3_moe_with_torchair - pytest -sv tests/e2e/multicard/test_torchair_graph_mode.py::test_e2e_deepseekv2lite_with_torchair - pytest -sv tests/e2e/multicard/test_torchair_graph_mode.py::test_e2e_deepseekv2lite_with_torchair_v1scheduler - pytest -sv tests/e2e/multicard/test_torchair_graph_mode.py::test_e2e_deepseekv2lite_with_nz - name: Run vllm-project/vllm-ascend test (full) env: @@ -183,7 +178,6 @@ jobs: run: | pytest -sv tests/e2e/multicard/test_quantization.py pytest -sv tests/e2e/multicard/test_aclgraph_capture_replay.py - pytest -sv tests/e2e/multicard/test_torchair_graph_mode.py pytest -sv tests/e2e/multicard/test_full_graph_mode.py pytest -sv tests/e2e/multicard/test_data_parallel.py pytest -sv tests/e2e/multicard/test_expert_parallel.py diff --git a/.github/workflows/vllm_ascend_test_nightly_a2.yaml b/.github/workflows/vllm_ascend_test_nightly_a2.yaml index 01ee56ca..a18cd5d4 100644 --- a/.github/workflows/vllm_ascend_test_nightly_a2.yaml +++ b/.github/workflows/vllm_ascend_test_nightly_a2.yaml @@ -127,9 +127,6 @@ jobs: - name: multi-node-deepseek-dp config_file_path: DeepSeek-R1-W8A8-A2.yaml size: 2 - - name: multi-node-deepseek-dp-torchair - config_file_path: DeepSeek-R1-W8A8-A2-torchair.yaml - size: 2 uses: ./.github/workflows/_e2e_nightly_multi_node.yaml with: soc_version: a2 diff --git a/.github/workflows/vllm_ascend_test_pr_light.yaml b/.github/workflows/vllm_ascend_test_pr_light.yaml index ad4cbe6c..264e0a62 100644 --- a/.github/workflows/vllm_ascend_test_pr_light.yaml +++ b/.github/workflows/vllm_ascend_test_pr_light.yaml @@ -134,8 +134,6 @@ jobs: run: | export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/Ascend/ascend-toolkit/latest/arm64-linux/devlib pytest -sv --cov --cov-report=xml:unittests-coverage.xml tests/ut \ - --ignore tests/ut/torchair/models/test_torchair_deepseek_mtp.py \ - --ignore tests/ut/torchair/models/test_torchair_deepseek_v2.py \ --ignore tests/ut/model_loader/netloader/test_netloader_elastic.py \ --ignore tests/ut/kv_connector/test_remote_prefill_lifecycle.py \ --ignore tests/ut/kv_connector/test_remote_decode_lifecycle.py \ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 4440bb5b..82bde178 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -12,7 +12,7 @@ repos: - id: codespell args: [ --toml, pyproject.toml, - '--skip', 'tests/e2e/multicard/test_torchair_graph_mode.py,csrc/**,tests/prompts/**,./benchmarks/sonnet.txt,*tests/lora/data/**,build/**,./vllm_ascend.egg-info/**,.github/**,typos.toml', + '--skip', 'csrc/**,tests/prompts/**,./benchmarks/sonnet.txt,*tests/lora/data/**,build/**,./vllm_ascend.egg-info/**,.github/**,typos.toml', '-L', 'CANN,cann,NNAL,nnal,ASCEND,ascend,EnQue,CopyIn,ArchType,AND,ND' ] additional_dependencies: diff 
--git a/docs/source/developer_guide/contribution/testing.md b/docs/source/developer_guide/contribution/testing.md index b5ae289d..20206979 100644 --- a/docs/source/developer_guide/contribution/testing.md +++ b/docs/source/developer_guide/contribution/testing.md @@ -251,7 +251,6 @@ This will reproduce the E2E test. See [vllm_ascend_test.yaml](https://github.com - Offline test example: [`tests/e2e/singlecard/test_offline_inference.py`](https://github.com/vllm-project/vllm-ascend/blob/main/tests/e2e/singlecard/test_offline_inference.py) - Online test examples: [`tests/e2e/singlecard/test_prompt_embedding.py`](https://github.com/vllm-project/vllm-ascend/blob/main/tests/e2e/singlecard/test_prompt_embedding.py) - Correctness test example: [`tests/e2e/singlecard/test_aclgraph.py`](https://github.com/vllm-project/vllm-ascend/blob/main/tests/e2e/singlecard/test_aclgraph.py) -- Reduced Layer model test example: [test_torchair_graph_mode.py - DeepSeek-V3-Pruning](https://github.com/vllm-project/vllm-ascend/blob/20767a043cccb3764214930d4695e53941de87ec/tests/e2e/multicard/test_torchair_graph_mode.py#L48) The CI resource is limited, and you might need to reduce the number of layers of a model. Below is an example of how to generate a reduced layer model: 1. Fork the original model repo in modelscope. All the files in the repo except for weights are required. diff --git a/docs/source/developer_guide/feature_guide/Multi_Token_Prediction.md b/docs/source/developer_guide/feature_guide/Multi_Token_Prediction.md index 04bde6fe..f8503e78 100644 --- a/docs/source/developer_guide/feature_guide/Multi_Token_Prediction.md +++ b/docs/source/developer_guide/feature_guide/Multi_Token_Prediction.md @@ -6,7 +6,7 @@ MTP boosts inference performance by parallelizing the prediction of multiple tok ## How to Use MTP To enable MTP for DeepSeek-V3 models, add the following parameter when starting the service: ---speculative_config ' {"method": "deepseek_mtp", "num_speculative_tokens": 1, "disable_padded_drafter_batch": False} ' +--speculative_config ' {"method": "mtp", "num_speculative_tokens": 1, "disable_padded_drafter_batch": False} ' - `num_speculative_tokens`: The number of speculative tokens which enable model to predict multiple tokens at once, if provided. It will default to the number in the draft model config if present, otherwise, it is required. - `disable_padded_drafter_batch`: Disable input padding for speculative decoding. If set to True, speculative input batches can contain sequences of different lengths, which may only be supported by certain attention backends. This currently only affects the MTP method of speculation, default is False. @@ -74,21 +74,18 @@ If the bonus token is accepted, the MTP model performs inference for (num_specul ### Method Validation -- Currently, the spec_decode scenario only supports methods such as ngram, eagle, eagle3, and deepseek_mtp. If an incorrect parameter is passed for the method, the code will raise an error to alert the user that an incorrect method was provided. +- Currently, the spec_decode scenario only supports methods such as ngram, eagle, eagle3, and mtp. If an incorrect parameter is passed for the method, the code will raise an error to alert the user that an incorrect method was provided. 
``` def get_spec_decode_method(method, vllm_config, device, - runner, - is_torchair_graph=False): + runner): if method == "ngram": return NgramProposer(vllm_config, device, runner) elif method in ["eagle", "eagle3"]: return EagleProposer(vllm_config, device, runner) - elif method == 'deepseek_mtp': - if is_torchair_graph: - return TorchairMtpProposer(vllm_config, device, runner) + elif method == 'mtp': return MtpProposer(vllm_config, device, runner) else: raise ValueError("Unknown speculative decoding method: " diff --git a/docs/source/tutorials/DeepSeek-V3.1.md b/docs/source/tutorials/DeepSeek-V3.1.md index fd172408..ccaf4ce7 100644 --- a/docs/source/tutorials/DeepSeek-V3.1.md +++ b/docs/source/tutorials/DeepSeek-V3.1.md @@ -128,9 +128,8 @@ vllm serve /weights/DeepSeek-V3.1_w8a8mix_mtp \ --trust-remote-code \ --no-enable-prefix-caching \ --gpu-memory-utilization 0.92 \ ---speculative-config '{"num_speculative_tokens": 1, "method": "deepseek_mtp"}' \ +--speculative-config '{"num_speculative_tokens": 1, "method": "mtp"}' \ --compilation-config '{"cudagraph_mode": "FULL_DECODE_ONLY"}' \ ---additional-config '{"torchair_graph_config":{"enabled":false}}' ``` ### Multi-node Deployment @@ -190,9 +189,8 @@ vllm serve /weights/DeepSeek-V3.1_w8a8mix_mtp \ --trust-remote-code \ --no-enable-prefix-caching \ --gpu-memory-utilization 0.94 \ ---speculative-config '{"num_speculative_tokens": 1, "method": "deepseek_mtp"}' \ +--speculative-config '{"num_speculative_tokens": 1, "method": "mtp"}' \ --compilation-config '{"cudagraph_mode": "FULL_DECODE_ONLY"}' \ ---additional-config '{"torchair_graph_config":{"enabled":false}}' ``` **Node 1** @@ -247,9 +245,8 @@ vllm serve /weights/DeepSeek-V3.1_w8a8mix_mtp \ --trust-remote-code \ --no-enable-prefix-caching \ --gpu-memory-utilization 0.94 \ ---speculative-config '{"num_speculative_tokens": 1, "method": "deepseek_mtp"}' \ +--speculative-config '{"num_speculative_tokens": 1, "method": "mtp"}' \ --compilation-config '{"cudagraph_mode": "FULL_DECODE_ONLY"}' \ ---additional-config '{"torchair_graph_config":{"enabled":false}}' ``` ### Prefill-Decode Disaggregation @@ -421,7 +418,7 @@ vllm serve /weights/DeepSeek-V3.1_w8a8mix_mtp \ --gpu-memory-utilization 0.9 \ --quantization ascend \ --no-enable-prefix-caching \ - --speculative-config '{"num_speculative_tokens": 1, "method": "deepseek_mtp"}' \ + --speculative-config '{"num_speculative_tokens": 1, "method": "mtp"}' \ --additional-config '{"recompute_scheduler_enable":true,"enable_shared_expert_dp": true}' \ --kv-transfer-config \ '{"kv_connector": "MooncakeConnector", diff --git a/docs/source/tutorials/DeepSeek-V3.2-Exp.md b/docs/source/tutorials/DeepSeek-V3.2-Exp.md index 79518dc4..132e7efc 100644 --- a/docs/source/tutorials/DeepSeek-V3.2-Exp.md +++ b/docs/source/tutorials/DeepSeek-V3.2-Exp.md @@ -173,8 +173,7 @@ vllm serve vllm-ascend/DeepSeek-V3.2-Exp-W8A8 \ --enable-expert-parallel \ --trust-remote-code \ --no-enable-prefix-caching \ ---gpu-memory-utilization 0.92 \ ---additional-config '{"torchair_graph_config":{"enabled":true,"graph_batch_sizes":[16]}}' +--gpu-memory-utilization 0.92 ``` ### Multi-node Deployment @@ -225,8 +224,7 @@ vllm serve /root/.cache/Modelers_Park/DeepSeek-V3.2-Exp \ --max-num-batched-tokens 17450 \ --trust-remote-code \ --no-enable-prefix-caching \ ---gpu-memory-utilization 0.9 \ ---additional-config '{"torchair_graph_config":{"enabled":true,"graph_batch_sizes":[16]}}' +--gpu-memory-utilization 0.9 ``` **Node 1** @@ -269,8 +267,7 @@ vllm serve 
/root/.cache/Modelers_Park/DeepSeek-V3.2-Exp \ --enable-expert-parallel \ --trust-remote-code \ --no-enable-prefix-caching \ ---gpu-memory-utilization 0.92 \ ---additional-config '{"torchair_graph_config":{"enabled":true,"graph_batch_sizes":[16]}}' +--gpu-memory-utilization 0.92 ``` :::: @@ -316,8 +313,7 @@ vllm serve vllm-ascend/DeepSeek-V3.2-Exp-W8A8 \ --trust-remote-code \ --quantization ascend \ --no-enable-prefix-caching \ ---gpu-memory-utilization 0.9 \ ---additional-config '{"torchair_graph_config":{"enabled":true,"graph_batch_sizes":[16]}}' +--gpu-memory-utilization 0.9 ``` **Node 1** @@ -362,8 +358,7 @@ vllm serve vllm-ascend/DeepSeek-V3.2-Exp-W8A8 \ --trust-remote-code \ --quantization ascend \ --no-enable-prefix-caching \ ---gpu-memory-utilization 0.92 \ ---additional-config '{"torchair_graph_config":{"enabled":true,"graph_batch_sizes":[16]}}' +--gpu-memory-utilization 0.92 ``` :::: diff --git a/docs/source/tutorials/multi_node.md b/docs/source/tutorials/multi_node.md index fbfbbd74..5a30c716 100644 --- a/docs/source/tutorials/multi_node.md +++ b/docs/source/tutorials/multi_node.md @@ -136,8 +136,7 @@ vllm serve vllm-ascend/DeepSeek-V3.1-W8A8 \ --max-num-batched-tokens 8192 \ --trust-remote-code \ --no-enable-prefix-caching \ ---gpu-memory-utilization 0.9 \ ---additional-config '{"torchair_graph_config":{"enabled":true}}' +--gpu-memory-utilization 0.9 ``` **Node 1** @@ -181,8 +180,7 @@ vllm serve vllm-ascend/DeepSeek-V3.1-W8A8 \ --enable-expert-parallel \ --trust-remote-code \ --no-enable-prefix-caching \ ---gpu-memory-utilization 0.92 \ ---additional-config '{"torchair_graph_config":{"enabled":true}}' +--gpu-memory-utilization 0.92 ``` The deployment view looks like: diff --git a/docs/source/tutorials/multi_node_kimi.md b/docs/source/tutorials/multi_node_kimi.md index d53ef09d..f37d9bf4 100644 --- a/docs/source/tutorials/multi_node_kimi.md +++ b/docs/source/tutorials/multi_node_kimi.md @@ -92,8 +92,7 @@ vllm serve /home/cache/weights/Kimi-K2-Instruct-W8A8 \ --max-num-batched-tokens 8192 \ --trust-remote-code \ --no-enable-prefix-caching \ ---gpu-memory-utilization 0.9 \ ---additional-config '{"torchair_graph_config":{"enabled":true}}' +--gpu-memory-utilization 0.9 ``` **Node 1** @@ -136,8 +135,7 @@ vllm serve /home/cache/weights/Kimi-K2-Instruct-W8A8 \ --enable-expert-parallel \ --trust-remote-code \ --no-enable-prefix-caching \ ---gpu-memory-utilization 0.92 \ ---additional-config '{"torchair_graph_config":{"enabled":true}}' +--gpu-memory-utilization 0.92 ``` The deployment view looks like: diff --git a/docs/source/tutorials/multi_npu_moge.md b/docs/source/tutorials/multi_npu_moge.md index 91806ba7..91e22845 100644 --- a/docs/source/tutorials/multi_npu_moge.md +++ b/docs/source/tutorials/multi_npu_moge.md @@ -153,12 +153,7 @@ if __name__ == "__main__": enable_expert_parallel=True, distributed_executor_backend="mp", max_model_len=1024, - trust_remote_code=True, - additional_config={ - 'torchair_graph_config': { - 'enabled': True, - } - }) + trust_remote_code=True) outputs = llm.generate(prompts, sampling_params) for output in outputs: diff --git a/docs/source/user_guide/configuration/additional_config.md b/docs/source/user_guide/configuration/additional_config.md index 129c4997..c975c6b1 100644 --- a/docs/source/user_guide/configuration/additional_config.md +++ b/docs/source/user_guide/configuration/additional_config.md @@ -27,7 +27,6 @@ The following table lists additional configuration options available in vLLM Asc | Name | Type | Default | Description | 
|-------------------------------------|------|---------|-----------------------------------------------------------------------------------------------------------------------------------------------| | `xlite_graph_config` | dict | `{}` | Configuration options for xlite graph mode | -| `torchair_graph_config` | dict | `{}` | Configuration options for torchair graph mode | | `weight_prefetch_config` | dict | `{}` | Configuration options for weight prefetch | | `refresh` | bool | `false` | Whether to refresh global Ascend configuration content. This is usually used by rlhf or ut/e2e test case. | | `expert_map_path` | str | `None` | When using expert load balancing for an MoE model, an expert map path needs to be passed in. | @@ -52,21 +51,6 @@ The details of each configuration option are as follows: | `enabled` | bool | `False` | Whether to enable xlite graph mode. Currently only Llama or Qwen dense series models are supported. | | `full_mode` | bool | `False` | Whether to enable xlite for both the prefill and decode stages. By default, xlite is only enabled for the decode stage. | -**torchair_graph_config** - -| Name | Type | Default | Description | -| ---- | ---- | ------- | ----------- | -| `enabled` | bool | `False` | Whether to enable torchair graph mode. Currently only DeepSeek series models and PanguProMoE are supported. | -| `mode` | str | `None` | When using reduce-overhead mode for torchair, it needs to be set. | -| `enable_multistream_mla`| bool | `False` | Whether to put vector operators of MLA to another stream. This option only takes effect on models using MLA (for example, DeepSeek). | -| `enable_view_optimize` | bool | `True` | Whether to enable torchair view optimization. | -| `enable_frozen_parameter` | bool | `True` | Whether to fix the memory address of weights during inference to reduce the input address refresh time during graph execution. | -| `use_cached_graph` | bool | `False` | Whether to use cached graph. | -| `graph_batch_sizes` | list[int] | `[]` | The batch size for torchair graph cache. | -| `graph_batch_sizes_init` | bool | `False` | Init graph batch size dynamically if `graph_batch_sizes` is empty. | -| `enable_kv_nz`| bool | `False` | Whether to enable KV Cache NZ layout. This option only takes effect on models using MLA (for example, DeepSeek). | -| `enable_super_kernel` | bool | `False` | Whether to enable super kernel to fuse operators in deepseek moe layers. This option only takes effects on moe models using dynamic w8a8 quantization.| - **weight_prefetch_config** | Name | Type | Default | Description | @@ -80,13 +64,6 @@ An example of additional configuration is as follows: ``` { - "torchair_graph_config": { - "enabled": True, - "use_cached_graph": True, - "graph_batch_sizes": [1, 2, 4, 8], - "graph_batch_sizes_init": False, - "enable_kv_nz": False - }, "weight_prefetch_config": { "enabled": True, "prefetch_ratio": { diff --git a/docs/source/user_guide/feature_guide/graph_mode.md b/docs/source/user_guide/feature_guide/graph_mode.md index 81958e45..76868949 100644 --- a/docs/source/user_guide/feature_guide/graph_mode.md +++ b/docs/source/user_guide/feature_guide/graph_mode.md @@ -10,9 +10,8 @@ This guide provides instructions for using Ascend Graph Mode with vLLM Ascend. P From v0.9.1rc1 with V1 Engine, vLLM Ascend will run models in graph mode by default to keep the same behavior with vLLM. 
If you hit any issues, please feel free to open an issue on GitHub and fallback to the eager mode temporarily by setting `enforce_eager=True` when initializing the model. -There are three kinds for graph mode supported by vLLM Ascend: +There are two kinds for graph mode supported by vLLM Ascend: - **ACLGraph**: This is the default graph mode supported by vLLM Ascend. In v0.9.1rc1, Qwen and Deepseek series models are well tested. -- **TorchAirGraph**: This is the GE graph mode. In v0.9.1rc1, only DeepSeek series models are supported. - **XliteGraph**: This is the euler xlite graph mode. In v0.11.0, only Llama and Qwen dense serise models are supported. ## Using ACLGraph @@ -35,29 +34,6 @@ Online example: vllm serve Qwen/Qwen2-7B-Instruct ``` -## Using TorchAirGraph - -If you want to run DeepSeek series models with the graph mode, you should use [TorchAirGraph](https://www.hiascend.com/document/detail/zh/Pytorch/700/modthirdparty/torchairuseguide/torchair_0002.html). In this case, additional configuration is required. - -Offline example: - -```python -import os -from vllm import LLM - -# TorchAirGraph only works without chunked-prefill now -model = LLM(model="path/to/DeepSeek-R1-0528", additional_config={"torchair_graph_config": {"enabled": True}}) -outputs = model.generate("Hello, how are you?") -``` - -Online example: - -```shell -vllm serve path/to/DeepSeek-R1-0528 --additional-config='{"torchair_graph_config": {"enabled": true}}' -``` - -You can find more details about additional configuration [here](../configuration/additional_config.md). - ## Using XliteGraph If you want to run Llama or Qwen dense series models with xlite graph mode, please install xlite, and set xlite_graph_config. @@ -87,7 +63,7 @@ You can find more details abort xlite [here](https://gitee.com/openeuler/GVirt/b ## Fallback to the Eager Mode -If `ACLGraph`, `TorchAirGraph` and `XliteGraph` all fail to run, you should fallback to the eager mode. +If `ACLGraph` and `XliteGraph` all fail to run, you should fallback to the eager mode. Offline example: diff --git a/docs/source/user_guide/feature_guide/quantization.md b/docs/source/user_guide/feature_guide/quantization.md index 8a6e3676..8632b74f 100644 --- a/docs/source/user_guide/feature_guide/quantization.md +++ b/docs/source/user_guide/feature_guide/quantization.md @@ -104,22 +104,3 @@ First, make sure you specify `ascend` as the quantization method. Second, check ### 2. How to solve the error "Could not locate the configuration_deepseek.py"? Please convert DeepSeek series models using `br_release_MindStudio_8.1.RC2_TR5_20260624` ModelSlim, where the missing configuration_deepseek.py error has been fixed. - -### 3. What should be considered when converting DeepSeek series models with ModelSlim? - -When the MLA portion of the weights used the `W8A8_DYNAMIC` quantization with the torchair graph mode enabled, modify the configuration file in the CANN package to prevent incorrect inference results. - -The operation steps are as follows: - -1. Search in the CANN package directory, for example: -find /usr/local/Ascend/ -name fusion_config.json - -2. 
Add `"AddRmsNormDynamicQuantFusionPass":"off",` and `"MultiAddRmsNormDynamicQuantFusionPass":"off",` to the fusion_config.json you find, the location is as follows: - -```bash -{ - "Switch":{ - "GraphFusion":{ - "AddRmsNormDynamicQuantFusionPass":"off", - "MultiAddRmsNormDynamicQuantFusionPass":"off", -``` diff --git a/examples/run_dp_server.sh b/examples/run_dp_server.sh index c6ff7aa6..0607b48e 100644 --- a/examples/run_dp_server.sh +++ b/examples/run_dp_server.sh @@ -27,5 +27,4 @@ vllm serve Qwen/Qwen1.5-MoE-A2.7B \ --max-num-batched-tokens 4096 \ --gpu-memory-utilization 0.9 \ --trust-remote-code \ - --enforce-eager \ - --additional-config '{"torchair_graph_config":{"enabled":false, "use_cached_graph":false}}' + --enforce-eager diff --git a/tests/e2e/310p/test_offline_inference_parallel_310p.py b/tests/e2e/310p/test_offline_inference_parallel_310p.py deleted file mode 100644 index 7ba7ef73..00000000 --- a/tests/e2e/310p/test_offline_inference_parallel_310p.py +++ /dev/null @@ -1,59 +0,0 @@ -# -# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. -# Copyright 2023 The vLLM team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# This file is a part of the vllm-ascend project. -import pytest -import vllm # noqa: F401 - -import vllm_ascend # noqa: F401 -from tests.e2e.conftest import VllmRunner - -# Pangu local model path -MODELS = [ - "IntervitensInc/pangu-pro-moe-model", -] -# set additional config for ascend scheduler and torchair graph -ADDITIONAL_CONFIG = [{ - "additional_config": { - "torchair_graph_config": { - "enabled": True - } - } -}] - - -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["float16"]) -@pytest.mark.parametrize("max_tokens", [5]) -@pytest.mark.parametrize("enfore_eager", [True, False]) -@pytest.mark.parametrize("additional_config", ADDITIONAL_CONFIG) -def test_pangu_model(model: str, dtype: str, max_tokens: int, - enfore_eager: bool, additional_config: dict) -> None: - if enfore_eager: - additional_config = {} - example_prompts = [ - "Hello, my name is", - "The future of AI is", - ] - - with VllmRunner(model, - tensor_parallel_size=4, - dtype=dtype, - max_model_len=1024, - enforce_eager=True, - enable_expert_parallel=True, - additional_config=additional_config, - distributed_executor_backend="mp") as vllm_model: - vllm_model.generate_greedy(example_prompts, max_tokens) diff --git a/tests/e2e/multicard/test_offline_inference_distributed.py b/tests/e2e/multicard/test_offline_inference_distributed.py index 1380c49e..67c87332 100644 --- a/tests/e2e/multicard/test_offline_inference_distributed.py +++ b/tests/e2e/multicard/test_offline_inference_distributed.py @@ -78,9 +78,6 @@ def test_models_distributed_DeepSeek_multistream_moe(): tensor_parallel_size=2, distributed_executor_backend="mp", additional_config={ - "torchair_graph_config": { - "enabled": True, - }, "enable_multistream_moe": True, "refresh": True, }, @@ -144,17 +141,12 @@ def test_models_distributed_DeepSeek_W4A8DYNAMIC(model): "Hello, my name is", ] max_tokens = 
5 - with VllmRunner( - snapshot_download(model), - dtype="auto", - tensor_parallel_size=2, - quantization="ascend", - enforce_eager=True, - enable_expert_parallel=True, - additional_config={"torchair_graph_config": { - "enabled": False, - }}, - ) as vllm_model: + with VllmRunner(snapshot_download(model), + dtype="auto", + tensor_parallel_size=2, + quantization="ascend", + enforce_eager=True, + enable_expert_parallel=True) as vllm_model: vllm_model.generate_greedy(prompts, max_tokens) diff --git a/tests/e2e/multicard/test_torchair_graph_mode.py b/tests/e2e/multicard/test_torchair_graph_mode.py deleted file mode 100644 index 3472051e..00000000 --- a/tests/e2e/multicard/test_torchair_graph_mode.py +++ /dev/null @@ -1,290 +0,0 @@ -# -# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. -# Copyright 2023 The vLLM team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# This file is a part of the vllm-ascend project. -# -"""Compare the short outputs of HF and vLLM when using greedy sampling. - -Run `pytest tests/multicard/test_torchair_graph_mode.py`. -""" -import os -from typing import Dict - -import pytest - -from tests.e2e.conftest import VllmRunner - -os.environ["PYTORCH_NPU_ALLOC_CONF"] = "max_split_size_mb:256" - - -def _deepseek_torchair_test_fixture( - additional_config: Dict, - *, - tensor_parallel_size=2, - use_v1_schduler=False, -): - example_prompts = [ - "Hello, my name is", - "The president of the United States is", - "The capital of France is", - "The future of AI is", - ] - kwargs = {} - if not use_v1_schduler: - kwargs = { - "refresh": True, - } - additional_config.update(**kwargs) - - with VllmRunner( - "vllm-ascend/DeepSeek-V3-Pruning", - dtype="half", - tensor_parallel_size=tensor_parallel_size, - distributed_executor_backend="mp", - additional_config=additional_config, - ) as vllm_model: - # use greedy sampler to make sure the generated results are fix - vllm_output = vllm_model.generate_greedy(example_prompts, 5) - - # NOTE: vllm-ascend/DeepSeek-V3-Pruning is a random weight of - # DeepSeek-V3 with 2 hidden layers, thus the golden results seems - # inaccurate. This will only change if accuracy improves with the - # official weights of DeepSeek-V3. 
- golden_results = [ - 'Hello, my name is下载早点向前很有่อง', - 'The president of the United States isSender)## physiological Albany', - 'The capital of France is Rocky转角 hospitalizedinterval sparked', - 'The future of AI is её asegο BIOS一扫', - ] - - assert len(golden_results) == len(vllm_output) - for i in range(len(vllm_output)): - assert golden_results[i] == vllm_output[i][1] - print(f"Generated text: {vllm_output[i][1]!r}") - - -def test_e2e_deepseekv3_with_torchair(): - additional_config = { - "torchair_graph_config": { - "enabled": True, - }, - } - _deepseek_torchair_test_fixture(additional_config) - - -def test_e2e_deepseekv3_with_torchair_ms_mla(): - additional_config = { - "torchair_graph_config": { - "enabled": True, - "enable_multistream_mla": True, - }, - } - _deepseek_torchair_test_fixture(additional_config) - - -@pytest.mark.skip("accuracy test failed. Fix me") -def test_e2e_deepseekv3_with_torchair_v1scheduler(): - additional_config = { - "torchair_graph_config": { - "enabled": True, - }, - } - _deepseek_torchair_test_fixture(additional_config, use_v1_schduler=True) - - -def _pangu_torchair_test_fixture( - additional_config: Dict, - *, - tensor_parallel_size=2, -): - example_prompts = [ - "Hello, my name is", - "The president of the United States is", - "The capital of France is", - "The future of AI is", - ] - - # torchair is only work without chunked-prefill now - kwargs = { - "refresh": True, - } - additional_config.update(**kwargs) - - with VllmRunner( - "vllm-ascend/pangu-pro-moe-pruing", - dtype="half", - tensor_parallel_size=tensor_parallel_size, - distributed_executor_backend="mp", - additional_config=additional_config, - enable_expert_parallel=True, - ) as vllm_model: - # use greedy sampler to make sure the generated results are fix - vllm_output = vllm_model.generate_greedy(example_prompts, 5) - - # NOTE: vllm-ascend/pangu-pro-moe-pruing is only part of PanguProMoE - # with 2 hidden layers, thus the golden results seems inaccurate. - # This will only change if accuracy changes with the official weights - # of PanguProMoE. - golden_results = [ - 'Hello, my name is Remempondeprecatedmiot忱', - 'The president of the United States is Remem下的一个 rever ceremoni Segnali', - 'The capital of France is Rememvoud administrativ Remem投', - 'The future of AI isotope Segnali Zoeken精细化 supus', - ] - - assert len(golden_results) == len(vllm_output) - for i in range(len(vllm_output)): - assert golden_results[i] == vllm_output[i][1] - print(f"Generated text: {vllm_output[i][1]!r}") - - -@pytest.mark.skip("skipping test_e2e_pangu_with_torchair") -def test_e2e_pangu_with_torchair(): - additional_config = { - "torchair_graph_config": { - "enabled": True, - }, - } - _pangu_torchair_test_fixture(additional_config) - - -def _qwen_torchair_test_fixture( - model, - tp, - enable_expert_parallel, -): - # The current access control does not support 16 cards, - # so the MC2 operator in Qwen's graph mode cannot run. - # Once 16-card support is available, - # this e2e can be switched to graph mode. 
- example_prompts = [ - "Hello, my name is", - "The president of the United States is", - "The capital of France is", - "The future of AI is", - ] - - additional_config = { - "torchair_graph_config": { - "enabled": False, - }, - "refresh": True, - } - - with VllmRunner( - model, - dtype="half", - tensor_parallel_size=tp, - distributed_executor_backend="mp", - enforce_eager=True, - additional_config=additional_config, - enable_expert_parallel=enable_expert_parallel, - ) as vllm_model: - # use greedy sampler to make sure the generated results are fix - vllm_output = vllm_model.generate_greedy(example_prompts, 5) - - # NOTE: vllm-ascend/pangu-pro-moe-pruing is only part of PanguProMoE - # with 2 hidden layers, thus the golden results seems inaccurate. - # This will only change if accuracy changes with the official weights - # of PanguProMoE. - golden_results = [ - 'Hello, my name is Remempondeprecatedmiot忱', - 'The president of the United States is Remem下的一个 rever ceremoni Segnali', - 'The capital of France is Rememvoud administrativ Remem投', - 'The future of AI isotope Segnali Zoeken精细化 supus', - ] - - assert len(golden_results) == len(vllm_output) - for i in range(len(vllm_output)): - print(f"Generated text: {vllm_output[i][1]!r}") - - -def test_e2e_qwen2_with_torchair(): - _qwen_torchair_test_fixture("Qwen/Qwen2.5-0.5B-Instruct", 2, False) - - -def test_e2e_qwen3_moe_with_torchair(): - _qwen_torchair_test_fixture("Qwen/Qwen3-30B-A3B", 2, True) - - -# test deepseek-v2-lite -def _deepseek_v2_lite_torchair_test_fixure( - additional_config: Dict, - *, - tensor_parallel_size=2, - use_v1_schduler=False, -): - example_prompts = [ - "Hello, my name is", - "The president of the United States is", - "The capital of France is", - "The future of AI is", - ] - - kwargs = {} - if not use_v1_schduler: - kwargs = { - "refresh": True, - } - additional_config.update(**kwargs) - - with VllmRunner( - "deepseek-ai/DeepSeek-V2-Lite", - dtype="half", - tensor_parallel_size=tensor_parallel_size, - distributed_executor_backend="mp", - additional_config=additional_config, - ) as vllm_model: - vllm_output = vllm_model.generate_greedy(example_prompts, 5) - - # NOTE: deepseek-ai/DeepSeek-V2-Lite is a random weight of - # DeepSeek-V2-Lite with 2 hidden layers, thus the golden results seems - # inaccurate. This will only change if accuracy improves with the - # official weights of DeepSeek-V2-Lite. 
- - for i in range(len(vllm_output)): - generated_text = vllm_output[i][1] - assert len( - generated_text.strip()) > 0, f"The {i}-th output is null, failed" - - -def test_e2e_deepseekv2lite_with_torchair(): - additional_config = { - "torchair_graph_config": { - "enabled": True, - }, - } - _deepseek_v2_lite_torchair_test_fixure(additional_config) - - -def test_e2e_deepseekv2lite_with_torchair_v1scheduler(): - additional_config = { - "torchair_graph_config": { - "enabled": True, - }, - } - _deepseek_v2_lite_torchair_test_fixure(additional_config, - use_v1_schduler=True) - - -# kv_cache enable e2e test -def test_e2e_deepseekv2lite_with_nz(): - additional_config = { - "torchair_graph_config": { - "enabled": True, - "enable_kv_nz": True, - }, - } - _deepseek_v2_lite_torchair_test_fixure(additional_config) diff --git a/tests/e2e/nightly/features/test_mtpx_deepseek_r1_0528_w8a8.py b/tests/e2e/nightly/features/test_mtpx_deepseek_r1_0528_w8a8.py index 7a782258..539d62d7 100644 --- a/tests/e2e/nightly/features/test_mtpx_deepseek_r1_0528_w8a8.py +++ b/tests/e2e/nightly/features/test_mtpx_deepseek_r1_0528_w8a8.py @@ -73,7 +73,6 @@ async def test_models(model: str, mode: str) -> None: "VLLM_RPC_TIMEOUT": "3600000", "VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS": "3600000" } - additional_config: dict[str, Any] = {} speculative_config = {"num_speculative_tokens": 2, "method": "mtp"} compilation_config = { "cudagraph_capture_sizes": [56], @@ -104,7 +103,6 @@ async def test_models(model: str, mode: str) -> None: ["--speculative-config", json.dumps(speculative_config)]) server_args.extend(["--gpu-memory-utilization", "0.92"]) - additional_config["torchair_graph_config"] = {"enabled": True} aisbench_cases = aisbench_gsm8k if mode == "mtp3": env_dict["HCCL_OP_EXPANSION_MODE"] = "AIV" @@ -117,9 +115,7 @@ async def test_models(model: str, mode: str) -> None: server_args.extend( ["--compilation-config", json.dumps(compilation_config)]) - additional_config["torchair_graph_config"] = {"enabled": False} aisbench_cases = aisbench_aime - server_args.extend(["--additional-config", json.dumps(additional_config)]) request_keyword_args: dict[str, Any] = { **api_keyword_args, } diff --git a/tests/e2e/nightly/features/test_prefix_cache_deepseek_r1_0528_w8a8.py b/tests/e2e/nightly/features/test_prefix_cache_deepseek_r1_0528_w8a8.py index 3776e49c..8a281a26 100644 --- a/tests/e2e/nightly/features/test_prefix_cache_deepseek_r1_0528_w8a8.py +++ b/tests/e2e/nightly/features/test_prefix_cache_deepseek_r1_0528_w8a8.py @@ -74,13 +74,6 @@ async def test_models(model: str) -> None: "PYTORCH_NPU_ALLOC_CONF": "expandable_segments:True", } additional_config = { - "torchair_graph_config": { - "enabled": True, - "enable_multistream_moe": False, - "enable_multistream_mla": True, - "graph_batch_size": [16], - "use_cached_graph": True - }, "chunked_prefill_for_mla": True, "enable_weight_nz_layout": True } diff --git a/tests/e2e/nightly/models/test_deepseek_r1_0528_w8a8.py b/tests/e2e/nightly/models/test_deepseek_r1_0528_w8a8.py index 7a76a4a1..dcb83b14 100644 --- a/tests/e2e/nightly/models/test_deepseek_r1_0528_w8a8.py +++ b/tests/e2e/nightly/models/test_deepseek_r1_0528_w8a8.py @@ -29,7 +29,6 @@ MODELS = [ ] MODES = [ - "torchair", "single", "aclgraph", "aclgraph_mlapo", @@ -78,13 +77,6 @@ async def test_models(model: str, mode: str) -> None: } speculative_config = {"num_speculative_tokens": 1, "method": "mtp"} additional_config = { - "torchair_graph_config": { - "enabled": True, - "enable_multistream_moe": False, - "enable_multistream_mla": True, - 
"graph_batch_sizes": [16], - "use_cached_graph": True - }, "chunked_prefill_for_mla": True, "enable_weight_nz_layout": True } @@ -99,12 +91,8 @@ async def test_models(model: str, mode: str) -> None: ] if mode == "single": server_args.append("--enforce-eager") - additional_config["torchair_graph_config"] = {"enabled": False} - if mode == "aclgraph": - additional_config["torchair_graph_config"] = {"enabled": False} if mode == "aclgraph_mlapo": env_dict["VLLM_ASCEND_ENABLE_MLAPO"] = "1" - additional_config["torchair_graph_config"] = {"enabled": False} server_args.extend(["--additional-config", json.dumps(additional_config)]) request_keyword_args: dict[str, Any] = { **api_keyword_args, diff --git a/tests/e2e/nightly/models/test_deepseek_r1_w8a8_eplb.py b/tests/e2e/nightly/models/test_deepseek_r1_w8a8_eplb.py index 3f504ae9..498eadfe 100644 --- a/tests/e2e/nightly/models/test_deepseek_r1_w8a8_eplb.py +++ b/tests/e2e/nightly/models/test_deepseek_r1_w8a8_eplb.py @@ -68,9 +68,6 @@ async def test_models(model: str) -> None: "cudagraph_mode": "FULL_DECODE_ONLY" } additional_config: dict[str, Any] = { - "torchair_graph_config": { - "enabled": True - }, "enable_shared_expert_dp": False, "multistream_overlap_shared_expert": False, "dynamic_eplb": True, diff --git a/tests/e2e/nightly/models/test_deepseek_v3_2_exp_w8a8.py b/tests/e2e/nightly/models/test_deepseek_v3_2_exp_w8a8.py index 9d5b78f0..8ec25cbb 100644 --- a/tests/e2e/nightly/models/test_deepseek_v3_2_exp_w8a8.py +++ b/tests/e2e/nightly/models/test_deepseek_v3_2_exp_w8a8.py @@ -72,27 +72,13 @@ async def test_models(model: str, tp_size: int, dp_size: int, port = get_open_port() env_dict = {"HCCL_BUFFSIZE": "1024", "VLLM_ASCEND_ENABLE_MLAPO": "0"} server_args = [ - "--no-enable-prefix-caching", - "--enable-expert-parallel", + "--no-enable-prefix-caching", "--enable-expert-parallel", "--tensor-parallel-size", - str(tp_size), - "--data-parallel-size", - str(dp_size), - "--port", - str(port), - "--max-model-len", - "16384", - "--max-num-batched-tokens", - "16384", - "--block-size", - "16", - "--trust-remote-code", - "--quantization", - "ascend", - "--gpu-memory-utilization", - "0.9", - "--additional-config", - '{"torchair_graph_config":{"enabled":true,"graph_batch_sizes":[16]}}', + str(tp_size), "--data-parallel-size", + str(dp_size), "--port", + str(port), "--max-model-len", "16384", "--max-num-batched-tokens", + "16384", "--block-size", "16", "--trust-remote-code", "--quantization", + "ascend", "--gpu-memory-utilization", "0.9" ] if full_graph: server_args += [ diff --git a/tests/e2e/nightly/multi_node/config/models/DeepSeek-R1-W8A8-A2-torchair.yaml b/tests/e2e/nightly/multi_node/config/models/DeepSeek-R1-W8A8-A2-torchair.yaml deleted file mode 100644 index 6754bdc8..00000000 --- a/tests/e2e/nightly/multi_node/config/models/DeepSeek-R1-W8A8-A2-torchair.yaml +++ /dev/null @@ -1,64 +0,0 @@ -test_name: "test DeepSeek-R1-W8A8 torchair on A2" -model: "vllm-ascend/DeepSeek-R1-0528-W8A8" -num_nodes: 2 -npu_per_node: 8 -env_common: - VLLM_USE_MODELSCOPE: true - HCCL_BUFFSIZE: 1024 - SERVER_PORT: 8080 - OMP_PROC_BIND: false - OMP_NUM_THREADS: 10 - - -deployment: - - - server_cmd: > - vllm serve vllm-ascend/DeepSeek-R1-0528-W8A8 - --host 0.0.0.0 - --port $SERVER_PORT - --data-parallel-size 4 - --data-parallel-size-local 2 - --data-parallel-address $LOCAL_IP - --data-parallel-rpc-port 13399 - --no-enable-prefix-caching - --max-num-seqs 16 - --tensor-parallel-size 4 - --max-model-len 36864 - --max-num-batched-tokens 6000 - --enable-expert-parallel - 
--trust-remote-code - --quantization ascend - --gpu-memory-utilization 0.9 - --speculative-config '{"num_speculative_tokens": 1, "method":"mtp"}' - --additional-config '{"torchair_graph_config":{"enabled":true,"enable_multistream_moe":true},"chunked_prefill_for_mla":true,"enable_weight_nz_layout":true}' - - - - server_cmd: > - vllm serve vllm-ascend/DeepSeek-R1-0528-W8A8 - --headless - --data-parallel-size 4 - --data-parallel-rpc-port 13399 - --data-parallel-size-local 2 - --data-parallel-start-rank 2 - --data-parallel-address $MASTER_IP - --no-enable-prefix-caching - --max-num-seqs 16 - --tensor-parallel-size 4 - --max-model-len 36864 - --max-num-batched-tokens 6000 - --enable-expert-parallel - --trust-remote-code - --quantization ascend - --gpu-memory-utilization 0.9 - --speculative-config '{"num_speculative_tokens": 1, "method":"mtp"}' - --additional-config '{"torchair_graph_config":{"enabled":true,"enable_multistream_moe":true},"chunked_prefill_for_mla":true,"enable_weight_nz_layout":true}' -benchmarks: - acc: - case_type: accuracy - dataset_path: vllm-ascend/gsm8k - request_conf: vllm_api_general_chat - dataset_conf: gsm8k/gsm8k_gen_0_shot_cot_chat_prompt - max_out_len: 32768 - batch_size: 512 - baseline: 95 - threshold: 5 diff --git a/tests/e2e/nightly/multi_node/config/models/DeepSeek-R1-W8A8-EPLB.yaml b/tests/e2e/nightly/multi_node/config/models/DeepSeek-R1-W8A8-EPLB.yaml index 961bb83f..0cf65ae2 100644 --- a/tests/e2e/nightly/multi_node/config/models/DeepSeek-R1-W8A8-EPLB.yaml +++ b/tests/e2e/nightly/multi_node/config/models/DeepSeek-R1-W8A8-EPLB.yaml @@ -58,7 +58,7 @@ deployment: } }' --additional-config - '{"torchair_graph_config":{"enabled":false,"enable_multistream_shared_expert":false},"enable_prefill_optimizations":true,"enable_weight_nz_layout":true,"dynamic_eplb":true,"num_iterations_eplb_update":2048,"num_wait_worker_iterations":200}' + '{"enable_prefill_optimizations":true,"enable_weight_nz_layout":true,"dynamic_eplb":true,"num_iterations_eplb_update":2048,"num_wait_worker_iterations":200}' - server_cmd: > @@ -96,7 +96,7 @@ deployment: } }' --additional-config - '{"torchair_graph_config":{"enabled":false,"enable_multistream_shared_expert":false},"enable_prefill_optimizations":true,"enable_weight_nz_layout":true,"dynamic_eplb":true,"num_iterations_eplb_update":2048,"num_wait_worker_iterations":200}' + '{"enable_prefill_optimizations":true,"enable_weight_nz_layout":true,"dynamic_eplb":true,"num_iterations_eplb_update":2048,"num_wait_worker_iterations":200}' - server_cmd: > vllm serve vllm-ascend/DeepSeek-R1-0528-W8A8 @@ -135,7 +135,7 @@ deployment: } }' --additional-config - '{"torchair_graph_config":{"enabled":true,"enable_multistream_mla":true,"graph_batch_sizes":[28],"use_cached_graph":true,"enable_super_kernel":false},"multistream_overlap_shared_expert":true,"dynamic_eplb":true,"num_iterations_eplb_update":2048,"num_wait_worker_iterations":200}' + '{"multistream_overlap_shared_expert":true,"dynamic_eplb":true,"num_iterations_eplb_update":2048,"num_wait_worker_iterations":200}' - server_cmd: > vllm serve vllm-ascend/DeepSeek-R1-0528-W8A8 @@ -173,7 +173,7 @@ deployment: } }' --additional-config - '{"torchair_graph_config":{"enabled":true,"enable_multistream_mla":true,"graph_batch_sizes":[28],"use_cached_graph":true,"enable_super_kernel":false},"multistream_overlap_shared_expert":true,"dynamic_eplb":true,"num_iterations_eplb_update":2048,"num_wait_worker_iterations":200}' + 
'{"multistream_overlap_shared_expert":true,"dynamic_eplb":true,"num_iterations_eplb_update":2048,"num_wait_worker_iterations":200}' benchmarks: perf: case_type: performance diff --git a/tests/e2e/nightly/multi_node/config/models/DeepSeek-R1-W8A8.yaml b/tests/e2e/nightly/multi_node/config/models/DeepSeek-R1-W8A8.yaml index 37455738..966abb86 100644 --- a/tests/e2e/nightly/multi_node/config/models/DeepSeek-R1-W8A8.yaml +++ b/tests/e2e/nightly/multi_node/config/models/DeepSeek-R1-W8A8.yaml @@ -57,7 +57,7 @@ deployment: } }' --additional-config - '{"torchair_graph_config":{"enabled":false,"enable_multistream_shared_expert":false},"enable_prefill_optimizations":true,"enable_weight_nz_layout":true}' + '{"enable_prefill_optimizations":true,"enable_weight_nz_layout":true}' - server_cmd: > @@ -95,7 +95,7 @@ deployment: } }' --additional-config - '{"torchair_graph_config":{"enabled":false,"enable_multistream_shared_expert":false},"enable_prefill_optimizations":true,"enable_weight_nz_layout":true}' + '{"enable_prefill_optimizations":true,"enable_weight_nz_layout":true}' - server_cmd: > vllm serve vllm-ascend/DeepSeek-R1-0528-W8A8 @@ -134,7 +134,7 @@ deployment: } }' --additional-config - '{"torchair_graph_config":{"enabled":true,"enable_multistream_mla":true,"graph_batch_sizes":[28],"use_cached_graph":true,"enable_super_kernel":false},"multistream_overlap_shared_expert":true}' + '{"multistream_overlap_shared_expert":true}' - server_cmd: > vllm serve vllm-ascend/DeepSeek-R1-0528-W8A8 @@ -172,7 +172,7 @@ deployment: } }' --additional-config - '{"torchair_graph_config":{"enabled":true,"enable_multistream_mla":true,"graph_batch_sizes":[28],"use_cached_graph":true,"enable_super_kernel":false},"multistream_overlap_shared_expert":true}' + '{"multistream_overlap_shared_expert":true}' benchmarks: perf: case_type: performance diff --git a/tests/e2e/nightly/multi_node/config/models/DeepSeek-V3.yaml b/tests/e2e/nightly/multi_node/config/models/DeepSeek-V3.yaml index b9a584ed..09afa24c 100644 --- a/tests/e2e/nightly/multi_node/config/models/DeepSeek-V3.yaml +++ b/tests/e2e/nightly/multi_node/config/models/DeepSeek-V3.yaml @@ -82,7 +82,6 @@ deployment: --trust-remote-code --no-enable-prefix-caching --gpu-memory-utilization 0.9 - --additional-config '{"torchair_graph_config":{"enabled":true}}' --kv-transfer-config '{"kv_connector": "MooncakeConnector", "kv_role": "kv_consumer", diff --git a/tests/e2e/nightly/multi_node/config/models/DeepSeek-V3_2-Exp-bf16.yaml b/tests/e2e/nightly/multi_node/config/models/DeepSeek-V3_2-Exp-bf16.yaml index 93e76ca5..77577f03 100644 --- a/tests/e2e/nightly/multi_node/config/models/DeepSeek-V3_2-Exp-bf16.yaml +++ b/tests/e2e/nightly/multi_node/config/models/DeepSeek-V3_2-Exp-bf16.yaml @@ -29,7 +29,6 @@ deployment: --trust-remote-code --no-enable-prefix-caching --gpu-memory-utilization 0.9 - --additional-config '{"torchair_graph_config":{"enabled":true,"graph_batch_sizes":[16]}}' - server_cmd: > @@ -49,5 +48,4 @@ deployment: --trust-remote-code --no-enable-prefix-caching --gpu-memory-utilization 0.92 - --additional-config '{"torchair_graph_config":{"enabled":true,"graph_batch_sizes":[16]}}' benchmarks: diff --git a/tests/e2e/singlecard/spec_decode_v1/test_v1_mtp_torchair_correctness.py b/tests/e2e/singlecard/spec_decode_v1/test_v1_mtp_torchair_correctness.py deleted file mode 100644 index ddaeeab9..00000000 --- a/tests/e2e/singlecard/spec_decode_v1/test_v1_mtp_torchair_correctness.py +++ /dev/null @@ -1,106 +0,0 @@ -from __future__ import annotations - -import pytest -from vllm import 
SamplingParams -from vllm.config import CompilationConfig, CUDAGraphMode - -from tests.e2e.conftest import VllmRunner - - -@pytest.fixture -def sampling_config(): - return SamplingParams(temperature=0, max_tokens=256, ignore_eos=False) - - -@pytest.fixture -def model_name(): - return "wemaster/deepseek_mtp_main_random_bf16" - - -def mtp_torchair_correctness( - sampling_config: SamplingParams, - model_name: str, - graph_mode: CUDAGraphMode = CUDAGraphMode.PIECEWISE, -): - example_prompts = [ - "Hello, my name is", - "The president of the United States is", - "The capital of France is", - "The future of AI is", - ] - ''' - Compare the outputs of a original LLM and a speculative LLM - should be the same when using mtp speculative decoding. - ''' - with VllmRunner(model_name, - tensor_parallel_size=1, - gpu_memory_utilization=0.7, - max_model_len=256, - enforce_eager=False, - additional_config={ - "torchair_graph_config": { - "enabled": True, - "use_cached_graph": False, - "graph_batch_sizes": [1, 2, 4], - }, - "multistream_overlap_shared_expert": "True" - }) as ref_llm: - ref_outputs = ref_llm.generate(example_prompts, sampling_config) - - graph_mode_str = "PIECEWISE" - if graph_mode == CUDAGraphMode.FULL: - graph_mode_str = "FULL" - - with VllmRunner(model_name, - tensor_parallel_size=1, - max_num_seqs=256, - gpu_memory_utilization=0.7, - distributed_executor_backend="mp", - enable_expert_parallel=True, - speculative_config={ - "method": "mtp", - "num_speculative_tokens": 1, - }, - enforce_eager=False, - max_model_len=2000, - compilation_config=CompilationConfig( - cudagraph_mode=graph_mode_str), - additional_config={ - "torchair_graph_config": { - "enabled": True, - "use_cached_graph": False, - "graph_batch_sizes": [1, 2, 4], - }, - "multistream_overlap_shared_expert": "True" - }) as spec_llm: - spec_outputs = spec_llm.generate(example_prompts, sampling_config) - - matches = 0 - misses = 0 - for ref_output, spec_output in zip(ref_outputs, spec_outputs): - ref_token_ids = ref_output[0][0] - spec_token_ids = spec_output[0][0] - if ref_token_ids == spec_token_ids[:len(ref_token_ids)]: - matches += 1 - else: - misses += 1 - print(f"ref_output: {ref_output[1][0]}") - print(f"spec_output: {spec_output[1][0]}") - - # Heuristic: expect at least 66% of the prompts to match exactly - # Upon failure, inspect the outputs to check for inaccuracy. 
- assert matches > int(0.66 * len(ref_outputs)) - - -def test_mtp_torchair_correctness_piecewise( - sampling_config: SamplingParams, - model_name: str, -): - mtp_torchair_correctness(sampling_config, model_name) - - -def test_mtp_torchair_correctness_full( - sampling_config: SamplingParams, - model_name: str, -): - mtp_torchair_correctness(sampling_config, model_name, CUDAGraphMode.FULL) diff --git a/tests/ut/ops/test_fused_moe.py b/tests/ut/ops/test_fused_moe.py index 7756071b..8faa3bb2 100644 --- a/tests/ut/ops/test_fused_moe.py +++ b/tests/ut/ops/test_fused_moe.py @@ -118,7 +118,6 @@ def mock_dist_env(mocker: MockerFixture): return_value=mock_dp_and_tp_group(mocker)), \ patch('vllm_ascend.ops.fused_moe.fused_moe.get_ascend_config', return_value=MagicMock( - torchair_graph_config=MagicMock(enabled=False), enable_multistream_moe=False, expert_map_path=None )), \ diff --git a/tests/ut/ops/test_rotary_embedding.py b/tests/ut/ops/test_rotary_embedding.py index f5d4f663..f3e263ff 100644 --- a/tests/ut/ops/test_rotary_embedding.py +++ b/tests/ut/ops/test_rotary_embedding.py @@ -110,11 +110,6 @@ class TestAscendRotaryEmbedding(unittest.TestCase): def test_rope_forward_oot_custom_kernel(self, mock_rotary_embedding, mock_custom_enabled, mock_soc_version, mock__c): - mock_config = MagicMock() - mock_config.torchair_graph_config.enabled = False - - # Setup mock for custom kernel path - mock__c.rotary_embedding.return_value = self.query, self.key vllm_config = VllmConfig() model_config = ModelConfig(MODEL, @@ -139,9 +134,6 @@ class TestAscendRotaryEmbedding(unittest.TestCase): @patch('vllm.distributed.parallel_state._TP', MagicMock(world_size=1)) def test_rope_forward_oot_contiguous(self, mock_npu_rotary, mock_custom_enabled): - mock_config = MagicMock() - mock_config.torchair_graph_config.enabled = False - # Test contiguous path when custom is disabled non_contig_query = self.query.transpose(0, 1) non_contig_key = self.key.transpose(0, 1) @@ -165,9 +157,6 @@ class TestAscendRotaryEmbedding(unittest.TestCase): @patch('vllm.distributed.parallel_state._DP', MagicMock(world_size=1)) @patch('vllm.distributed.parallel_state._TP', MagicMock(world_size=1)) def test_rope_forward_oot_with_offsets(self): - mock_config = MagicMock() - mock_config.torchair_graph_config.enabled = False - # Test that NotImplementedError is raised when offsets is provided offsets = torch.tensor([1, 2, 3]) with self.assertRaises(NotImplementedError): @@ -190,9 +179,6 @@ class TestAscendRotaryEmbedding(unittest.TestCase): @patch('vllm.distributed.parallel_state._TP', MagicMock(world_size=1)) def test_rope_forward_oot_neox_style_override(self, mock_npu_rotary, mock_custom_enabled): - mock_config = MagicMock() - mock_config.torchair_graph_config.enabled = False - # Test neox_style override vllm_config = VllmConfig() model_config = ModelConfig(MODEL, @@ -219,9 +205,6 @@ class TestAscendRotaryEmbedding(unittest.TestCase): @patch('vllm.distributed.parallel_state._TP', MagicMock(world_size=1)) def test_rope_forward_oot_rotary_dim_less_than_head_size( self, mock_npu_rotary, mock_custom_enabled): - mock_config = MagicMock() - mock_config.torchair_graph_config.enabled = False - # test case when rotary_dim < head_size org_rotary_dim = self.layer.rotary_dim self.layer.rotary_dim = self.layer.head_size // 2 @@ -415,7 +398,6 @@ class TestAscendMRotaryEmbedding(unittest.TestCase): mrope_section=self.mrope_section) self.mock_config = MagicMock() - self.mock_config.torchair_graph_config.enabled = False def _create_vllm_config(self): vllm_config = 
VllmConfig() diff --git a/tests/ut/quantization/test_w8a8_dynamic.py b/tests/ut/quantization/test_w8a8_dynamic.py index 76d510dd..ba7bf6ef 100644 --- a/tests/ut/quantization/test_w8a8_dynamic.py +++ b/tests/ut/quantization/test_w8a8_dynamic.py @@ -33,7 +33,6 @@ class TestAscendW8A8FusedMoEMethod(TestBase): mock_get_ep_group.return_value = mock_ep_group mock_ascend_config = Mock() - mock_ascend_config.torchair_graph_config = Mock(enabled=False) mock_ascend_config.enable_chunked_prefill = False mock_get_ascend_config.return_value = mock_ascend_config mock_mc2_group = Mock(device_group=0) diff --git a/tests/ut/test_ascend_config.py b/tests/ut/test_ascend_config.py index ac33ae15..a92bbc80 100644 --- a/tests/ut/test_ascend_config.py +++ b/tests/ut/test_ascend_config.py @@ -13,15 +13,10 @@ # This file is a part of the vllm-ascend project. # -import os - -from transformers import PretrainedConfig -from vllm.config import ModelConfig, ParallelConfig, VllmConfig +from vllm.config import VllmConfig from tests.ut.base import TestBase -from vllm_ascend.ascend_config import (_check_torchair_supported, - check_ascend_config, - clear_ascend_config, get_ascend_config, +from vllm_ascend.ascend_config import (clear_ascend_config, get_ascend_config, init_ascend_config) @@ -45,17 +40,6 @@ class TestAscendConfig(TestBase): self.assertIsNone(ascend_config.expert_map_path) self.assertFalse(ascend_config.multistream_overlap_shared_expert) - torchair_graph_config = ascend_config.torchair_graph_config - self.assertFalse(torchair_graph_config.enabled) - self.assertEqual(torchair_graph_config.mode, '') - self.assertFalse(torchair_graph_config.use_cached_graph) - self.assertEqual(torchair_graph_config.graph_batch_sizes, []) - self.assertFalse(torchair_graph_config.graph_batch_sizes_init) - self.assertFalse(torchair_graph_config.enable_multistream_mla) - self.assertTrue(torchair_graph_config.enable_view_optimize) - self.assertTrue(torchair_graph_config.enable_frozen_parameter) - self.assertFalse(torchair_graph_config.enable_kv_nz) - ascend_compilation_config = ascend_config.ascend_compilation_config self.assertTrue(ascend_compilation_config.enable_quantization_fusion) @@ -63,16 +47,6 @@ class TestAscendConfig(TestBase): def test_init_ascend_config_with_additional_config(self): test_vllm_config = VllmConfig() test_vllm_config.additional_config = { - "torchair_graph_config": { - "enabled": True, - "use_cached_graph": True, - "graph_batch_sizes": [1, 2, 4], - "graph_batch_sizes_init": False, - "enable_multistream_mla": True, - "enable_view_optimize": True, - "enable_frozen_parameter": True, - "enable_kv_nz": True - }, "ascend_compilation_config": { "enable_quantization_fusion": False, }, @@ -84,65 +58,9 @@ class TestAscendConfig(TestBase): self.assertEqual(ascend_config.expert_map_path, "test_expert_map_path") self.assertTrue(ascend_config.multistream_overlap_shared_expert) - torchair_graph_config = ascend_config.torchair_graph_config - self.assertTrue(torchair_graph_config.enabled) - self.assertTrue(torchair_graph_config.use_cached_graph) - self.assertEqual(torchair_graph_config.graph_batch_sizes, [1, 2, 4]) - self.assertFalse(torchair_graph_config.graph_batch_sizes_init) - self.assertTrue(torchair_graph_config.enable_multistream_mla) - self.assertTrue(torchair_graph_config.enable_view_optimize) - self.assertTrue(torchair_graph_config.enable_frozen_parameter) - self.assertTrue(torchair_graph_config.enable_kv_nz) ascend_compilation_config = ascend_config.ascend_compilation_config 
self.assertFalse(ascend_compilation_config.enable_quantization_fusion) - @_clean_up_ascend_config - def test_init_ascend_config_with_refresh(self): - test_vllm_config = VllmConfig() - ascend_config = init_ascend_config(test_vllm_config) - self.assertFalse(ascend_config.torchair_graph_config.enabled) - - test_vllm_config.additional_config = { - "torchair_graph_config": { - "enabled": True, - }, - } - ascend_config = init_ascend_config(test_vllm_config) - self.assertFalse(ascend_config.torchair_graph_config.enabled) - - test_vllm_config.additional_config = { - "torchair_graph_config": { - "enabled": True, - }, - "refresh": True, - } - ascend_config = init_ascend_config(test_vllm_config) - self.assertTrue(ascend_config.torchair_graph_config.enabled) - - @_clean_up_ascend_config - def test_init_ascend_config_with_wrong_input(self): - test_vllm_config = VllmConfig() - test_vllm_config.additional_config = { - "torchair_graph_config": { - "enabled": True, - "graph_batch_sizes": "fake_size", - }, - "refresh": True, - } - with self.assertRaises(TypeError): - init_ascend_config(test_vllm_config) - - test_vllm_config.additional_config = { - "torchair_graph_config": { - "enabled": False, - "graph_batch_sizes": [1, 2, 4, 8], - "graph_batch_sizes_init": True, - }, - "refresh": True, - } - with self.assertRaises(ValueError): - init_ascend_config(test_vllm_config) - @_clean_up_ascend_config def test_get_ascend_config(self): test_vllm_config = VllmConfig() @@ -162,203 +80,3 @@ class TestAscendConfig(TestBase): clear_ascend_config() with self.assertRaises(RuntimeError): get_ascend_config() - - @_clean_up_ascend_config - def test_check_ascend_config_pass(self): - test_vllm_config = VllmConfig() - init_ascend_config(test_vllm_config) - check_ascend_config(test_vllm_config, False) - - test_vllm_config.additional_config = { - "torchair_graph_config": { - "enabled": True, - }, - "refresh": True - } - init_ascend_config(test_vllm_config) - check_ascend_config(test_vllm_config, False) - - test_vllm_config.additional_config = { - "torchair_graph_config": { - "enabled": False, - }, - "refresh": True - } - init_ascend_config(test_vllm_config) - check_ascend_config(test_vllm_config, False) - - @_clean_up_ascend_config - def test_check_ascend_config_wrong_case(self): - test_vllm_config = VllmConfig() - - # torchair + eager mode - with self.assertRaises(RuntimeError): - test_vllm_config.additional_config = { - "torchair_graph_config": { - "enabled": True, - }, - "refresh": True - } - init_ascend_config(test_vllm_config) - enforce_eager = True - check_ascend_config(test_vllm_config, enforce_eager) - # torchair + non deepseek model - with self.assertRaises(NotImplementedError): - test_vllm_config.additional_config = { - "torchair_graph_config": { - "enabled": True, - }, - "refresh": True - } - model_path = os.path.join(os.path.dirname(__file__), "fake_weight") - fake_model_config = ModelConfig(model=model_path) - fake_model_config.hf_config = PretrainedConfig() - fake_model_config.hf_config.model_type = "llama" - test_vllm_config.model_config = fake_model_config - init_ascend_config(test_vllm_config) - check_ascend_config(test_vllm_config, False) - - def test_check_torchair_supported(self): - test_cases = [('deepseek_v3', True), ('PanguProMoE', True), - ('qwen', True), ('llama', False)] - for model_type, expected_output in test_cases: - self.assertEqual(_check_torchair_supported(model_type), - expected_output) - - @_clean_up_ascend_config - def test_ascend_config_load_error(self): - test_vllm_config = VllmConfig() - # 
graph_batch_sizes should be list. - with self.assertRaises(TypeError): - test_vllm_config.additional_config = { - "torchair_graph_config": { - "graph_batch_sizes": "fake_size", - }, - "refresh": True - } - init_ascend_config(test_vllm_config) - - # use_cached_graph should not be enabled without torchair graph mode - with self.assertRaises(RuntimeError): - test_vllm_config.additional_config = { - "torchair_graph_config": { - "enabled": False, - "use_cached_graph": True, - }, - "refresh": True - } - init_ascend_config(test_vllm_config) - - # use_cached_kv_cache_bytes should not be enabled without torchair graph mode - with self.assertRaises(RuntimeError): - test_vllm_config.additional_config = { - "torchair_graph_config": { - "enabled": False, - "use_cached_kv_cache_bytes": True, - }, - "refresh": True - } - init_ascend_config(test_vllm_config) - - # graph_batch_sizes should not be set without torchair graph mode - with self.assertRaises(RuntimeError): - test_vllm_config.additional_config = { - "torchair_graph_config": { - "enabled": False, - "graph_batch_sizes": [1, 2, 4], - }, - "refresh": True - } - init_ascend_config(test_vllm_config) - - # use_cached_kv_cache_bytes is valid only when torchair graph mode and use_cached_graph are enabled - with self.assertRaises(RuntimeError): - test_vllm_config.additional_config = { - "torchair_graph_config": { - "enabled": True, - "use_cached_graph": False, - "use_cached_kv_cache_bytes": True, - }, - "refresh": True - } - init_ascend_config(test_vllm_config) - - # graph_batch_sizes_init should not be enabled without torchair graph mode - with self.assertRaises(RuntimeError): - test_vllm_config.additional_config = { - "torchair_graph_config": { - "enabled": False, - "graph_batch_sizes_init": True, - }, - "refresh": True - } - init_ascend_config(test_vllm_config) - - # enable_multistream_mla should not be enabled without torchair graph mode - with self.assertRaises(RuntimeError): - test_vllm_config.additional_config = { - "torchair_graph_config": { - "enabled": False, - "enable_multistream_mla": True, - }, - "refresh": True - } - init_ascend_config(test_vllm_config) - - # mode should not be configured without torchair graph mode - with self.assertRaises(RuntimeError): - test_vllm_config.additional_config = { - "torchair_graph_config": { - "enabled": False, - "mode": 'max-autotune', - }, - "refresh": True - } - init_ascend_config(test_vllm_config) - - # enable_kv_nz should not be enabled without torchair graph mode - with self.assertRaises(RuntimeError): - test_vllm_config.additional_config = { - "torchair_graph_config": { - "enabled": False, - "enable_kv_nz": True, - }, - "refresh": True - } - init_ascend_config(test_vllm_config) - - with self.assertRaises(AssertionError): - test_vllm_config.additional_config = { - "lmhead_tensor_parallel_size": 2, - "refresh": True - } - test_vllm_config.parallel_config = ParallelConfig( - data_parallel_size=4, tensor_parallel_size=2) - init_ascend_config(test_vllm_config) - - with self.assertRaises(AssertionError): - test_vllm_config.additional_config = { - "torchair_graph_config": { - "enabled": True, - }, - "oproj_tensor_parallel_size": 2, - "refresh": True - } - test_vllm_config.parallel_config = ParallelConfig( - data_parallel_size=4, tensor_parallel_size=2) - init_ascend_config(test_vllm_config) - - with self.assertRaises(AssertionError): - test_vllm_config.additional_config = { - "torchair_graph_config": { - "enabled": False, - }, - "oproj_tensor_parallel_size": 2, - "refresh": True - } - 
test_vllm_config.parallel_config = ParallelConfig( - data_parallel_size=4, tensor_parallel_size=1) - model_path = os.path.join(os.path.dirname(__file__), "fake_weight") - test_vllm_config.model_config = ModelConfig(model=model_path, - enforce_eager=True) - init_ascend_config(test_vllm_config) diff --git a/tests/ut/test_platform.py b/tests/ut/test_platform.py index b6752317..230ebf00 100644 --- a/tests/ut/test_platform.py +++ b/tests/ut/test_platform.py @@ -31,7 +31,6 @@ class TestNPUPlatform(TestBase): @staticmethod def mock_vllm_ascend_config(): mock_ascend_config = MagicMock() - mock_ascend_config.torchair_graph_config.enabled = False mock_ascend_config.xlite_graph_config.enabled = False mock_ascend_config.enable_shared_expert_dp = False return mock_ascend_config @@ -403,47 +402,6 @@ class TestNPUPlatform(TestBase): CUDAGraphMode.NONE, ) - @patch('vllm_ascend.utils.get_ascend_device_type', - return_value=AscendDeviceType._910_93) - @patch("vllm_ascend.utils.update_default_aclgraph_sizes") - @patch("vllm_ascend.ascend_config.check_ascend_config") - @patch("vllm_ascend.ascend_config.init_ascend_config") - @patch( - "vllm_ascend.core.recompute_schedule_config.RecomputeSchedulerConfig.initialize_from_config" - ) - def test_check_and_update_config_torchair_enabled_compilation( - self, mock_init_recompute, mock_init_ascend, mock_check_ascend, - mock_update_default, mock_soc_version): - mock_update_default.return_value = MagicMock() - mock_ascend_config = TestNPUPlatform.mock_vllm_ascend_config() - mock_ascend_config.torchair_graph_config.enabled = True - mock_init_ascend.return_value = mock_ascend_config - vllm_config = TestNPUPlatform.mock_vllm_config() - vllm_config.model_config.enforce_eager = False - vllm_config.parallel_config.decode_context_parallel_size = 1 - vllm_config.parallel_config.prefill_context_parallel_size = 1 - vllm_config.parallel_config.tensor_parallel_size = 1 - mock_init_recompute.return_value = MagicMock() - vllm_config.scheduler_config = MagicMock() - - vllm_config.compilation_config.mode = CompilationMode.VLLM_COMPILE - - with self.assertLogs(logger="vllm", level="INFO") as cm: - from vllm_ascend import platform - - importlib.reload(platform) - self.platform.check_and_update_config(vllm_config) - self.assertTrue("Torchair compilation enabled" in cm.output[0]) - - self.assertEqual( - vllm_config.compilation_config.mode, - CompilationMode.NONE, - ) - self.assertEqual( - vllm_config.compilation_config.cudagraph_mode, - CUDAGraphMode.NONE, - ) - @patch('vllm_ascend.utils.get_ascend_device_type', return_value=AscendDeviceType._910_93) @patch("vllm_ascend.ascend_config.check_ascend_config") @@ -503,16 +461,6 @@ class TestNPUPlatform(TestBase): "vllm_ascend.worker.worker_v1.NPUWorker", ) - test_ascend_config = TestNPUPlatform.mock_vllm_ascend_config() - test_ascend_config.torchair_graph_config.enabled = True - mock_init_ascend.return_value = test_ascend_config - vllm_config.parallel_config.worker_cls = "auto" - self.platform.check_and_update_config(vllm_config) - self.assertEqual( - vllm_config.parallel_config.worker_cls, - "vllm_ascend.torchair.torchair_worker.NPUTorchairWorker", - ) - test_ascend_config = TestNPUPlatform.mock_vllm_ascend_config() test_ascend_config.xlite_graph_config.enabled = True mock_init_ascend.return_value = test_ascend_config @@ -550,14 +498,7 @@ class TestNPUPlatform(TestBase): self.platform.check_and_update_config(vllm_config) self.assertEqual(vllm_config.compilation_config.custom_ops, []) - @patch('vllm_ascend.platform.get_ascend_config') - def 
test_get_attn_backend_cls_use_v1_and_mla(self, mock_get_ascend_config): - mock_config = MagicMock() - mock_config.torchair_graph_config.enabled = False - mock_config.enable_shared_expert_dp = False - - mock_get_ascend_config.return_value = mock_config - + def test_get_attn_backend_cls_use_v1_and_mla(self): result = self.platform.get_attn_backend_cls( selected_backend="ascend", head_size=64, @@ -570,56 +511,7 @@ class TestNPUPlatform(TestBase): self.assertEqual(result, "vllm_ascend.attention.mla_v1.AscendMLABackend") - @patch('vllm_ascend.platform.get_ascend_config') - def test_get_attn_backend_cls_use_v1_mla_and_torchair( - self, mock_get_ascend_config): - mock_config = MagicMock() - mock_config.torchair_graph_config.enabled = True - - mock_get_ascend_config.return_value = mock_config - - result = self.platform.get_attn_backend_cls( - selected_backend="ascend", - head_size=64, - dtype="float16", - kv_cache_dtype="float16", - block_size=64, - #use_sfa=False, - use_mla=True, - ) - self.assertEqual( - result, - "vllm_ascend.torchair.torchair_mla.AscendMLATorchairBackend") - - @patch('vllm_ascend.platform.get_ascend_config') - def test_get_attn_backend_cls_use_v1_and_torchair(self, - mock_get_ascend_config): - mock_config = MagicMock() - mock_config.torchair_graph_config.enabled = True - - mock_get_ascend_config.return_value = mock_config - - result = self.platform.get_attn_backend_cls( - selected_backend="ascend", - head_size=64, - dtype="float16", - kv_cache_dtype="float16", - block_size=64, - #use_sfa=False, - use_mla=False, - ) - self.assertEqual( - result, - "vllm_ascend.torchair.torchair_attention.AscendAttentionTorchairBackend" - ) - - @patch('vllm_ascend.platform.get_ascend_config') - def test_get_attn_backend_cls_use_v1_only(self, mock_get_ascend_config): - mock_config = MagicMock() - mock_config.torchair_graph_config.enabled = False - - mock_get_ascend_config.return_value = mock_config - + def test_get_attn_backend_cls_use_v1_only(self): result = self.platform.get_attn_backend_cls( selected_backend="ascend", head_size=64, diff --git a/tests/ut/torchair/__init__.py b/tests/ut/torchair/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/tests/ut/torchair/models/test_qwen3_moe.py b/tests/ut/torchair/models/test_qwen3_moe.py deleted file mode 100644 index 77792ab5..00000000 --- a/tests/ut/torchair/models/test_qwen3_moe.py +++ /dev/null @@ -1,61 +0,0 @@ -from unittest.mock import Mock - -import pytest -from pytest_mock import MockerFixture -from transformers import PretrainedConfig -from vllm.distributed.parallel_state import GroupCoordinator - -from tests.ut.base import PytestBase -from vllm_ascend.torchair.models.qwen3_moe import CustomSparseMoeBlock - - -class TestCustomSparseMoeBlock(PytestBase): - - @pytest.fixture - def setup_csmb(self, mocker: MockerFixture): - config = PretrainedConfig(num_experts=64, - hidden_size=2048, - num_experts_per_tok=2, - moe_intermediate_size=1408, - norm_topk_prob=True) - mocker.patch( - 'vllm_ascend.torchair.models.qwen3_moe.get_tensor_model_parallel_world_size', - return_value=10) - mocker.patch( - 'vllm.model_executor.layers.linear.ReplicatedLinear.__init__', - return_value=None) - mocker.patch( - 'vllm_ascend.torchair.ops.torchair_fused_moe.TorchairAscendFusedMoE.__init__', - return_value=None) - - tp_group = Mock(spec=GroupCoordinator) - tp_group.rank_in_group = 0 - tp_group.world_size = 1 - tp_group.device_group = Mock() - - dp_group = Mock(spec=GroupCoordinator) - dp_group.rank_in_group = 0 - dp_group.world_size = 1 - - 
ep_group = Mock(spec=GroupCoordinator) - ep_group.rank_in_group = 0 - ep_group.world_size = 1 - - mocker.patch('vllm_ascend.torchair.models.qwen3_moe.get_tp_group', - return_value=tp_group) - mocker.patch('vllm_ascend.torchair.models.qwen3_moe.get_dp_group', - return_value=dp_group) - mocker.patch('vllm_ascend.torchair.models.qwen3_moe.get_ep_group', - return_value=ep_group) - ascend_config = mocker.MagicMock() - ascend_config.max_num_batched_tokens = 2048 - ascend_config.max_model_len = 1024 - mocker.patch("vllm_ascend.utils.get_ascend_config", - return_value=ascend_config) - - custom_moe_block = CustomSparseMoeBlock(config, None, "") - return custom_moe_block - - def test_init(self, mocker: MockerFixture, setup_csmb): - custom_moe_block = setup_csmb - assert isinstance(custom_moe_block, CustomSparseMoeBlock) diff --git a/tests/ut/torchair/models/test_torchair_deepseek_mtp.py b/tests/ut/torchair/models/test_torchair_deepseek_mtp.py deleted file mode 100644 index 59f83742..00000000 --- a/tests/ut/torchair/models/test_torchair_deepseek_mtp.py +++ /dev/null @@ -1,206 +0,0 @@ -import pytest -import torch -from pytest_mock import MockerFixture -from transformers import PretrainedConfig -from vllm.config import CacheConfig, ModelConfig, VllmConfig - -from tests.ut.base import PytestBase -from vllm_ascend.torchair.models.torchair_deepseek_mtp import ( - TorchairDeepSeekMTP, TorchairDeepSeekMultiTokenPredictor, - TorchairDeepSeekMultiTokenPredictorLayer) - - -class TestTorchairDeepSeekMultiTokenPredictorLayer(PytestBase): - - @pytest.fixture - def setup_mtp_layer(self, mocker: MockerFixture): - config = PretrainedConfig(vocab_size=1000, - hidden_size=768, - rms_norm_eps=1e-5) - mocker.patch( - 'vllm_ascend.torchair.models.torchair_deepseek_mtp.get_tensor_model_parallel_world_size', - return_value=1) - mocker.patch( - "vllm.model_executor.layers.vocab_parallel_embedding.VocabParallelEmbedding.__init__", - return_value=None) - mocker.patch("vllm.model_executor.layers.layernorm.RMSNorm.__init__", - return_value=None) - mocker.patch( - "vllm.model_executor.models.deepseek_mtp.SharedHead.__init__", - return_value=None) - mocker.patch( - "vllm_ascend.torchair.models.torchair_deepseek_mtp.TorchairDeepSeekShareHead.__init__", - return_value=None) - mocker_deepseek_v2_decode_layer = mocker.patch( - "vllm_ascend.torchair.models.torchair_deepseek_v2.TorchairDeepseekV2DecoderLayer.__init__", - return_value=None) - mocker.patch( - "vllm_ascend.ops.vocab_parallel_embedding.AscendVocabParallelEmbedding.__init__", - return_value=None) - ascend_config = mocker.MagicMock() - ascend_config.max_num_batched_tokens = 2048 - ascend_config.max_model_len = 1024 - mocker.patch("vllm_ascend.utils.get_ascend_config", - return_value=ascend_config) - - mtp_layer = TorchairDeepSeekMultiTokenPredictorLayer(config, "", None) - mocker_deepseek_v2_decode_layer.assert_called_once() - return mtp_layer - - def test_init(self, mocker: MockerFixture, setup_mtp_layer): - mtp_layer = setup_mtp_layer - assert isinstance(mtp_layer, TorchairDeepSeekMultiTokenPredictorLayer) - - def test_forward(self, mocker: MockerFixture, setup_mtp_layer): - mtp_layer = setup_mtp_layer - mocker.patch("torch.nn.Module.__setattr__") - mocker.patch("torch.nn.Module.__getattr__") - mocker.patch("torch.nn.Module.__delattr__") - mocker.patch.object(mtp_layer, - 'eh_proj', - return_value=torch.randn(2, 3, 768)) - mocker.patch("torch.cat", return_value=torch.randn(2, 3, 768)) - mtp_layer.mtp_block.return_value = (torch.randn(2, 3, 768), - torch.randn(2, 3, 768)) - 
mtp_layer.enorm.return_value = torch.randn(2, 3, 768) - mtp_layer.hnorm.return_value = torch.randn(2, 3, 768) - - input_ids = torch.tensor([[1, 2, 3], [4, 5, 6]]) - positions = torch.tensor([[0, 1, 2], [0, 1, 2]]) - kv_cache = torch.randn(2, 3, 768) - previous_hidden_states = torch.randn(2, 3, 768) - inputs_embeds = torch.tensor([[1.0, 2.0, 3.0]]) - - output = mtp_layer(input_ids, positions, kv_cache, None, - previous_hidden_states, inputs_embeds, 0) - assert output.shape == (3, 768) - - -class TestTorchairDeepSeekMultiTokenPredictor(PytestBase): - - @pytest.fixture - def setup_predictor(self, mocker: MockerFixture): - mock_vllm_config = mocker.MagicMock(spec=VllmConfig) - mock_model_config = mocker.MagicMock(spec=ModelConfig) - mock_hf_config = mocker.MagicMock() - mock_hf_config.num_hidden_layers = 12 - mock_hf_config.num_nextn_predict_layers = 3 - mock_hf_config.vocab_size = 30000 - mock_model_config.hf_config = mock_hf_config - mock_vllm_config.model_config = mock_model_config - mock_vllm_config.cache_config = CacheConfig() - mock_vllm_config.quant_config = mocker.MagicMock() - mocker.patch( - "vllm.model_executor.layers.vocab_parallel_embedding.VocabParallelEmbedding.__init__", - return_value=None) - mocker.patch( - "vllm_ascend.torchair.models.torchair_deepseek_mtp.TorchairDeepSeekMultiTokenPredictorLayer.__init__", - return_value=None) - mocker.patch( - "vllm_ascend.ops.vocab_parallel_embedding.AscendVocabParallelEmbedding.__init__", - return_value=None) - ascend_config = mocker.MagicMock() - ascend_config.max_num_batched_tokens = 2048 - ascend_config.max_model_len = 1024 - mocker.patch("vllm_ascend.utils.get_ascend_config", - return_value=ascend_config) - - predictor = TorchairDeepSeekMultiTokenPredictor( - vllm_config=mock_vllm_config) - return predictor - - def test_init(self, mocker: MockerFixture, setup_predictor): - predictor = setup_predictor - assert predictor.num_mtp_layers == 3 - assert isinstance(predictor, TorchairDeepSeekMultiTokenPredictor) - - @pytest.mark.parametrize( - 'kv_caches, inputs_embeds', - [(torch.tensor([[[0.1, 0.2, 0.3]]]), torch.tensor([[0.1, 0.2, 0.3]]))]) - def test_forward(self, mocker: MockerFixture, setup_predictor, kv_caches, - inputs_embeds): - predictor = setup_predictor - mock_layer = mocker.MagicMock() - mock_layer.return_value = torch.tensor([1.0, 2.0, 3.0]) - predictor.layers_list = [mock_layer] - - # todo: need or not? 
- # predictor.num_mtp_layers = 1 - input_ids = torch.tensor([[1, 2, 3]]) - positions = torch.tensor([[0, 1, 2]]) - mocker.patch( - "vllm_ascend.torchair.models.torchair_deepseek_mtp.TorchairDeepSeekMultiTokenPredictorLayer.__call__", - return_value=torch.tensor([[1.0, 2.0, 3.0]])) - output = predictor.forward(input_ids, positions, kv_caches, None, None, - inputs_embeds, 0) - mock_layer.assert_called_once() - assert torch.allclose(output, torch.tensor([1.0, 2.0, 3.0])) - - def test_compute_logits(self, mocker: MockerFixture, setup_predictor): - hidden_states = torch.tensor([[1, 2, 3], [4, 5, 6]]) - predictor = setup_predictor - - mock_layer = mocker.MagicMock() - mock_layer.return_value = torch.tensor([1.0, 2.0, 3.0]) - predictor.layers_list = [mock_layer] - mocker.patch("torch.nn.Module.__setattr__") - mocker.patch("torch.nn.Module.__getattr__") - mocker.patch("torch.nn.Module.__delattr__") - mocker.patch( - "vllm.model_executor.layers.logits_processor.LogitsProcessor.__init__", - return_value=None) - predictor.logits_processor.return_value = torch.tensor([1.0, 2.0, 3.0]) - - result_logits = predictor.compute_logits(hidden_states=hidden_states) - predictor.logits_processor.assert_called_once() - assert torch.allclose(result_logits, torch.tensor([1.0, 2.0, 3.0])) - - -class TestTorchairDeepSeekMTP(PytestBase): - - @pytest.fixture - def setup_mtp(self, mocker: MockerFixture): - vllm_config = mocker.MagicMock() - vllm_config.model_config.hf_config.num_hidden_layers = 12 - vllm_config.model_config.hf_config.num_nextn_predict_layers = 3 - vllm_config.cache_config = mocker.MagicMock() - vllm_config.quant_config = mocker.MagicMock() - - mocker.patch("torch.nn.Module.__setattr__") - mocker.patch("torch.nn.Module.__getattr__") - mocker.patch("torch.nn.Module.__delattr__") - mocker.patch( - "vllm.model_executor.layers.vocab_parallel_embedding.VocabParallelEmbedding.__init__", - return_value=None) - mocker.patch( - "vllm_ascend.torchair.models.torchair_deepseek_mtp.TorchairDeepSeekMultiTokenPredictorLayer.__call__", - return_value=None) - mocker.patch( - "vllm_ascend.ops.vocab_parallel_embedding.AscendVocabParallelEmbedding.__init__", - return_value=None) - ascend_config = mocker.MagicMock() - ascend_config.max_num_batched_tokens = 2048 - ascend_config.max_model_len = 1024 - mocker.patch("vllm_ascend.utils.get_ascend_config", - return_value=ascend_config) - - mtp = TorchairDeepSeekMTP(vllm_config=vllm_config) - return mtp - - def test_init(self, mocker: MockerFixture, setup_mtp): - mtp = setup_mtp - assert isinstance(mtp, TorchairDeepSeekMTP) - - def test_forward(self, mocker: MockerFixture, setup_mtp): - input_ids = torch.tensor([[1, 2, 3]]) - positions = torch.tensor([[0, 1, 2]]) - kv_caches = [torch.tensor([[0.1, 0.2, 0.3]])] - previous_hidden_states = torch.tensor([[0.1, 0.2, 0.3]]) - inputs_embeds = torch.tensor([[0.1, 0.2, 0.3]]) - spec_step_idx = 0 - setup_mtp.model.return_value = torch.tensor([[1.0, 2.0, 3.0]]) - - output = setup_mtp.forward(input_ids, positions, kv_caches, None, - previous_hidden_states, inputs_embeds, - spec_step_idx) - assert torch.allclose(output, torch.tensor([[1.0, 2.0, 3.0]])) diff --git a/tests/ut/torchair/models/test_torchair_deepseek_v2.py b/tests/ut/torchair/models/test_torchair_deepseek_v2.py deleted file mode 100644 index eb425670..00000000 --- a/tests/ut/torchair/models/test_torchair_deepseek_v2.py +++ /dev/null @@ -1,366 +0,0 @@ -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the 
License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# This file is a part of the vllm-ascend project. -# -from types import SimpleNamespace -from unittest.mock import MagicMock, Mock, patch - -import pytest -import torch -from transformers import PretrainedConfig -from vllm.config import CacheConfig -from vllm.distributed.parallel_state import GroupCoordinator -from vllm.transformers_utils.config import patch_rope_parameters - -from vllm_ascend.torchair.models.torchair_deepseek_v2 import ( - TorchairDeepseekV2DecoderLayer, TorchairDeepseekV2ForCausalLM, - TorchairDeepseekV2MergedReplicatedLinear, TorchairDeepseekV2MLAAttention, - TorchairDeepseekV2MLP, TorchairDeepseekV2MoE, - TorchairDeepseekV2RowParallelLinear, - TorchairDeepseekV2RowParallelLinearReplaceAllreduce, - TorchairDeepseekV2SiluAndMul) - - -@pytest.fixture -def base_config(): - config = PretrainedConfig( - hidden_size=128, - num_attention_heads=8, - num_hidden_layers=2, - intermediate_size=256, - hidden_act="silu", - rms_norm_eps=1e-6, - rope_theta=10000.0, - max_position_embeddings=2048, - n_routed_experts=4, - n_shared_experts=1, - moe_intermediate_size=256, - num_experts_per_tok=2, - routed_scaling_factor=1.0, - first_k_dense_replace=0, - moe_layer_freq=1, - kv_lora_rank=16, - qk_nope_head_dim=16, - qk_rope_head_dim=16, - v_head_dim=32, - topk_method="noaux_tc", - scoring_func="softmax", - norm_topk_prob=True, - n_group=1, - topk_group=1, - vocab_size=10000, - ) - patch_rope_parameters(config) - return config - - -@pytest.fixture -def vllm_config(base_config): - model_config = SimpleNamespace( - hf_config=base_config, - tensor_parallel_size=1, - dtype=torch.float32, - use_mla=False, - quant_config=None, - max_model_len=2048, - ) - - cache_config = CacheConfig() - vllm_config = Mock() - vllm_config.model_config = model_config - vllm_config.cache_config = cache_config - vllm_config.quant_config = None - return vllm_config - - -@pytest.fixture -def mock_distributed(): - tp_group = Mock(spec=GroupCoordinator) - tp_group.rank_in_group = 0 - tp_group.world_size = 1 - tp_group.device_group = Mock() - - dp_group = Mock(spec=GroupCoordinator) - dp_group.rank_in_group = 0 - dp_group.world_size = 1 - - ep_group = Mock(spec=GroupCoordinator) - ep_group.rank_in_group = 0 - ep_group.world_size = 1 - - pp_group = Mock(spec=GroupCoordinator) - pp_group.rank_in_group = 0 - pp_group.world_size = 1 - - dcp_group = MagicMock(spec=GroupCoordinator) - dcp_group.rank_in_group = 0 - dcp_group.world_size = 1 - dcp_group.device_group = MagicMock() - - mlp_tp_group = Mock(spec=GroupCoordinator) - mlp_tp_group.rank_in_group = 0 - mlp_tp_group.world_size = 1 - mlp_tp_group.all_gather = Mock(return_value=torch.randn(2, 4, 128)) - - mock_vllm_config = Mock() - mock_vllm_config.scheduler_config = Mock(max_num_seqs=256) - mock_vllm_config.model_config = Mock(max_model_len=2048, quant_config=None) - - with patch("vllm_ascend.torchair.models.torchair_deepseek_v2.get_tensor_model_parallel_rank", return_value=0), \ - patch("vllm_ascend.torchair.models.torchair_deepseek_v2.get_tensor_model_parallel_world_size", return_value=1), \ - 
patch("vllm_ascend.torchair.models.torchair_deepseek_v2.get_tp_group", return_value=tp_group), \ - patch("vllm_ascend.torchair.models.torchair_deepseek_v2.get_ep_group", return_value=ep_group), \ - patch("vllm_ascend.torchair.models.torchair_deepseek_v2.get_dp_group", return_value=dp_group), \ - patch("vllm_ascend.torchair.models.torchair_deepseek_v2.get_pp_group", return_value=pp_group), \ - patch("vllm_ascend.torchair.models.torchair_deepseek_v2.get_pp_group", - return_value=Mock(is_first_rank=False, is_last_rank=False)), \ - patch('vllm.distributed.parallel_state.get_dcp_group', return_value=dcp_group), \ - patch('vllm.distributed.parallel_state._DCP', new_callable=lambda: MagicMock(spec=GroupCoordinator)), \ - patch("vllm.distributed.get_decode_context_model_parallel_world_size", return_value=1),\ - patch("vllm_ascend.torchair.ops.torchair_fused_moe.get_current_vllm_config", return_value=mock_vllm_config), \ - patch.dict("vllm.distributed.parallel_state.__dict__", _TP=tp_group, _EP=ep_group, _DP=dp_group, - _PP=pp_group), \ - patch.dict("vllm_ascend.distributed.parallel_state.__dict__", _MC2=ep_group): - yield - - -@pytest.fixture -def mock_forward_context(): - forward_context = Mock(in_profile_run=False, with_prefill=False) - with patch( - "vllm_ascend.torchair.models.torchair_deepseek_v2.get_forward_context", - return_value=forward_context): - yield - - -@pytest.fixture -def patch_attention_init(): - try: - from vllm_ascend.torchair.models.torchair_deepseek_v2 import \ - DeepseekV2Attention - original_init = DeepseekV2Attention.__init__ - - def patched_init(self, *args, **kwargs): - kwargs.pop("decoder_layer", None) - if 'vllm_config' not in kwargs: - mock_vllm_config = Mock() - mock_vllm_config.model_config = Mock() - mock_vllm_config.model_config.hf_config = Mock() - mock_vllm_config.model_config.hf_config.hidden_size = 128 - mock_vllm_config.model_config.dtype = torch.float32 - mock_vllm_config.model_config.quant_config = None - mock_vllm_config.cache_config = CacheConfig() - kwargs['vllm_config'] = mock_vllm_config - return original_init(self, *args, **kwargs) - - DeepseekV2Attention.__init__ = patched_init - yield - DeepseekV2Attention.__init__ = original_init - except ImportError: - yield - - -def test_torchair_deepseek_v2_silu_and_mul(): - torch.set_default_device("cpu") - - silu = TorchairDeepseekV2SiluAndMul() - assert silu.weight_scale is None - - x = torch.randn(2, 4) - output = silu.forward_oot(x) - assert output.shape == (2, 2) - - weight_scale = Mock(return_value=torch.tensor(0.1)) - silu = TorchairDeepseekV2SiluAndMul(weight_scale=weight_scale) - quant_x = torch.randint(-128, 127, (2, 4), dtype=torch.int32) - dynamic_scale = torch.randn(2, 1) - with patch("torch_npu.npu_dequant_swiglu_quant", - return_value=torch.randn(2, 4)): - output = silu.forward_oot((quant_x, dynamic_scale)) - assert output.shape == (2, 4) - - -def test_torchair_deepseek_v2_merged_replicated_linear(mock_distributed): - linear = TorchairDeepseekV2MergedReplicatedLinear(input_size=128, - output_sizes=[64, 64], - bias=False, - quant_config=None) - assert linear.output_sizes == [64, 64] - - param = Mock() - param.data = torch.zeros(128, 128) - param.output_dim = 1 - param.is_gguf_weight = False - param.is_gguf_weight_type = False - loaded_weight = torch.randn(128, 64) - linear.weight_loader(param, loaded_weight, loaded_shard_id=0) - - with pytest.raises(AssertionError): - linear.weight_loader(param, torch.randn(128, 32), loaded_shard_id=0) - - -@pytest.mark.parametrize("cls", [ - 
TorchairDeepseekV2RowParallelLinearReplaceAllreduce, - TorchairDeepseekV2RowParallelLinear -]) -def test_row_parallel_linear(cls, mock_distributed, mock_forward_context): - linear = cls(input_size=128, output_size=64, bias=False, quant_config=None) - linear.quant_method = Mock() - linear.quant_method.apply.return_value = torch.randn(2, 4, 64) - - input_ = torch.randn(2, 4, 128) - with patch( - "vllm_ascend.torchair.models.torchair_deepseek_v2.split_tensor_along_last_dim", - return_value=[torch.randn(2, 4, 64)]): - linear.input_is_parallel = False - output = linear(input_, is_prefill=True) - assert output[0].shape == (2, 4, 64) - - linear.input_is_parallel = True - output = linear(input_, is_prefill=False) - assert output[0].shape == (2, 4, 64) - - -def test_torchair_deepseek_v2_mlp(mock_distributed, base_config): - mlp = TorchairDeepseekV2MLP(hidden_size=128, - intermediate_size=256, - hidden_act="silu", - quant_config=None) - assert isinstance(mlp.act_fn, TorchairDeepseekV2SiluAndMul) - with patch( - "vllm_ascend.torchair.models.torchair_deepseek_v2.QuantizationConfig" - ) as mock_quant_config: - mock_quant_config.name = "w8a8dynamic" - with pytest.raises(NotImplementedError): - TorchairDeepseekV2MLP(hidden_size=128, - intermediate_size=256, - hidden_act="silu", - quant_config=mock_quant_config, - force_replicate=False) - with pytest.raises(ValueError): - TorchairDeepseekV2MLP(hidden_size=128, - intermediate_size=256, - hidden_act="relu", - quant_config=None) - - -def test_torchair_deepseek_v2_moe(mock_distributed, base_config, - mock_forward_context): - base_config.n_shared_experts = 1 - moe = TorchairDeepseekV2MoE(config=base_config, - quant_config=None, - prefix="mlp") - assert moe.top_k == 2 - - x = torch.randn(2, 4, 128) - attn_metadata = Mock(num_prefills=1) - with patch( - "vllm_ascend.torchair.ops.torchair_fused_moe.TorchairAscendFusedMoE.__call__", - return_value=(torch.randn(2, 4, 128), torch.randn(2, 4, 128))): - output = moe(x, attn_metadata) - assert output.shape == (2, 4, 128) - - -@patch("torch_npu.npu_rms_norm") -def test_torchair_deepseek_v2_mla_attention(mock_rms_norm, mock_distributed, - base_config): - mock_rms_norm.return_value = (torch.randn(2, 128), torch.randn(2, 128)) - - attn = TorchairDeepseekV2MLAAttention(config=base_config, - hidden_size=128, - num_heads=8, - qk_nope_head_dim=16, - qk_rope_head_dim=16, - v_head_dim=32, - q_lora_rank=16, - kv_lora_rank=16, - cache_config=CacheConfig(), - quant_config=None, - prefix="layers.0.self_attn") - assert attn.debug_layer_idx == 0 - - x = torch.randn(2, 4, 128) - positions = torch.arange(4).repeat(2, 1) - with patch.object(attn.mla_attn, - "__call__", - return_value=torch.randn(2, 4, 128)): - with pytest.raises(AssertionError): - attn(positions, x) - - attn = TorchairDeepseekV2MLAAttention(config=base_config, - hidden_size=128, - num_heads=8, - qk_nope_head_dim=16, - qk_rope_head_dim=16, - v_head_dim=32, - q_lora_rank=None, - kv_lora_rank=16, - prefix="layers.1.self_attn") - assert hasattr(attn, "q_proj") - - -@patch("torch_npu.npu_add_rms_norm") -@patch("torch_npu.npu_rms_norm") -@patch("torch.ops.vllm.maybe_wait_prefetch_done", side_effect=lambda x: None) -def test_torchair_deepseek_v2_decoder_layer(mock_maybe_wait_prefetch_done, - mock_rms_norm, mock_add_norm, - mock_distributed, base_config, - vllm_config, mock_forward_context, - patch_attention_init): - mock_rms_norm.return_value = (torch.randn(2, 128), torch.randn(2, 128)) - mock_add_norm.return_value = (torch.randn(2, 128), torch.randn(2, 128), - torch.randn(2, 
128)) - base_config.n_routed_experts = 4 - layer = TorchairDeepseekV2DecoderLayer( - config=base_config, - prefix="layers.0", - model_config=vllm_config.model_config, - cache_config=CacheConfig(), - quant_config=None) - assert isinstance(layer.mlp, TorchairDeepseekV2MoE) - - x = torch.randn(2, 4, 128) - positions = torch.arange(4).repeat(2, 1) - - with patch.object(layer.self_attn, "forward", Mock(return_value=torch.randn(2, 4, 128))), \ - patch.object(layer.mlp, "forward", Mock(return_value=torch.randn(2, 4, 128))): - hidden_states, residual = layer(positions, x, None) - assert hidden_states.shape == (2, 4, 128) - - base_config.n_routed_experts = None - layer = TorchairDeepseekV2DecoderLayer( - config=base_config, - prefix="layers.0", - model_config=vllm_config.model_config, - quant_config=None) - assert isinstance(layer.mlp, TorchairDeepseekV2MLP) - - -def test_torchair_deepseek_v2_for_causal_lm(mock_distributed, vllm_config, - patch_attention_init): - model = TorchairDeepseekV2ForCausalLM(vllm_config=vllm_config) - - input_ids = torch.randint(0, 10000, (2, 4)) - positions = torch.arange(4).repeat(2, 1) - with patch.object(model.model, - "forward", - return_value=torch.randn(2, 4, 128)): - output = model(input_ids, positions) - assert output.shape == (2, 4, 128) - - weights = [("model.embed_tokens.weight", torch.randn(10000, 128))] - with patch( - "vllm.model_executor.model_loader.weight_utils.default_weight_loader" - ): - loaded = model.load_weights(weights) - assert loaded is not None diff --git a/tests/ut/torchair/ops/test_torchair_fused_moe.py b/tests/ut/torchair/ops/test_torchair_fused_moe.py deleted file mode 100644 index f0782e75..00000000 --- a/tests/ut/torchair/ops/test_torchair_fused_moe.py +++ /dev/null @@ -1,423 +0,0 @@ -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# This file is a part of the vllm-ascend project. 
-# -from typing import List, TypedDict -from unittest.mock import MagicMock, patch - -import pytest -import torch -import torch.nn as nn -import torch_npu -from pytest_mock import MockerFixture -from vllm.model_executor.layers.fused_moe import FusedMoEMethodBase - -import vllm_ascend -from vllm_ascend.ascend_forward_context import get_fused_moe_state -from vllm_ascend.quantization.quant_config import AscendFusedMoEMethod -from vllm_ascend.torchair.ops.torchair_fused_moe import ( - TorchairAscendFusedMoE, TorchairAscendUnquantizedFusedMoEMethod) -from vllm_ascend.utils import adapt_patch # noqa E402 -from vllm_ascend.utils import AscendDeviceType - -adapt_patch(True) - - -def mock_ep_and_mc2_group(mocker): - mock_group = mocker.MagicMock() - mock_group.rank_in_group = 0 - mock_group.rank = 0 - mock_group.world_size = 4 - mock_group.device_group = "mock_group_ep" - mock_group.all_to_all = MagicMock(return_value=torch.randn(8, 8)) - return mock_group - - -def mock_dp_and_tp_group(mocker): - mock_group = mocker.MagicMock() - mock_group.rank_in_group = 0 - mock_group.world_size = 2 - mock_group.device_group = "mock_group" - mock_group.all_gather = MagicMock(return_value=torch.randn(10, 32)) - return mock_group - - -@pytest.fixture -def mock_dist_env(mocker: MockerFixture): - # init dist env patch - dp_metadata = MagicMock(num_tokens_across_dp_cpu=[5, 5]) - - with patch('torch.npu.is_available', return_value=True), \ - patch('torch.distributed.get_rank', return_value=0), \ - patch('torch.distributed.get_world_size', return_value=4), \ - patch('vllm_ascend.torchair.ops.torchair_fused_moe.get_ep_group', return_value=mock_ep_and_mc2_group(mocker)), \ - patch('vllm_ascend.torchair.ops.torchair_fused_moe.get_mc2_group', return_value=mock_ep_and_mc2_group(mocker)), \ - patch('vllm_ascend.torchair.ops.torchair_fused_moe.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \ - patch('vllm.distributed.parallel_state.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \ - patch('vllm_ascend.torchair.ops.torchair_fused_moe.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \ - patch('vllm.model_executor.layers.fused_moe.layer.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \ - patch('torch.distributed.all_gather', return_value=MagicMock(return_value=torch.randn(10,32))), \ - patch('torch.distributed.all_to_all_single', return_value=torch.randn(8, 32)), \ - patch('vllm_ascend.torchair.ops.torchair_fused_moe.tensor_model_parallel_all_reduce', - return_value=torch.randn(5, 32)), \ - patch('vllm.model_executor.layers.fused_moe.config.get_dp_group', - return_value=mock_dp_and_tp_group(mocker)), \ - patch('vllm_ascend.torchair.ops.torchair_fused_moe.get_ascend_config', - return_value=MagicMock( - torchair_graph_config=MagicMock(enabled=False), - enable_multistream_moe=False, - enable_shared_expert_dp=False, - expert_map_path=None, - init_redundancy_expert=2, - )), \ - patch('vllm_ascend.torchair.ops.torchair_fused_moe.determine_expert_map', - return_value=(3, torch.tensor([0, 1, 2, -1, -1, -1, -1, -1]))), \ - patch('vllm_ascend.torchair.ops.torchair_fused_moe.get_forward_context', - return_value=MagicMock( - max_tokens_across_dp=10, - dp_metadata=dp_metadata, - )), \ - patch('vllm_ascend.torchair.ops.torchair_fused_moe.get_current_vllm_config', - return_value=MagicMock( - parallel_config=MagicMock(tensor_parallel_size=2), - scheduler_config=MagicMock(max_num_seqs=4), - model_config=MagicMock(max_model_len=2048) - )): - yield - - -@pytest.fixture -def mock_moe_env(mocker: 
MockerFixture): - # init moe env patch - - with patch('torch_npu.npu_moe_gating_top_k', return_value=( - torch.randn(8, 2), - torch.randint(0, 8, (8, 2)), - None - )), \ - patch('torch_npu.npu_moe_init_routing', return_value=( - torch.randn(8, 2), - torch.randint(0, 8, (8, 2)), - torch.tensor([0, 1, 2, 4, 6, 2, 7, 1]) - )), \ - patch("torch_npu.npu_moe_compute_expert_tokens", return_value=( - torch.randn(8, 2) - )), \ - patch("torch_npu.npu_moe_distribute_dispatch", return_value=( - torch.randn(16, 2) - )), \ - patch("torch_npu.npu_moe_distribute_combine", return_value=( - torch.randn(16, 2) - )), \ - patch("torch_npu.npu_grouped_matmul", return_value=( - [torch.randn(16, 2)] - )), \ - patch("torch_npu.npu_swiglu", return_value=( - torch.randn(16, 2) - )), \ - patch("torch_npu.npu_moe_gating_top_k_softmax", return_value=( - torch.randn(8, 2), - torch.randint(0, 8, (8, 2)), - torch.tensor([0, 1, 2, 4, 6, 2, 7, 1]) - )), \ - patch("torch_npu.npu_moe_finalize_routing", return_value=( - torch.randn(16, 2) - )): - if hasattr(torch_npu, 'npu_moe_distribute_dispatch_v2'): - with patch("torch_npu.npu_moe_distribute_dispatch_v2", return_value=( - torch.randn(16, 2))), \ - patch("torch_npu.npu_moe_distribute_combine_v2", return_value=( - torch.randn(16, 2))): - yield - else: - yield - - -@pytest.fixture -def default_moe_config(): - """default moe config""" - return { - 'num_experts': 8, - 'top_k': 2, - 'hidden_size': 512, - 'intermediate_size': 1024 - } - - -@pytest.fixture -def moe_method(mock_dist_env): - moe = MagicMock() - moe.moe_parallel_config.return_value = MagicMock(ep_size=4) - moe.moe_parallel_config.use_ep = False - moe.moe_parallel_config.dp_size = 1 - return TorchairAscendUnquantizedFusedMoEMethod(moe) - - -class Device(TypedDict): - device_id: int - device_expert: List[int] - - -class Layer(TypedDict): - layer_id: int - device_count: int - device_list: List[Device] - - -class MockData(TypedDict): - moe_layer_count: int - layer_list: List[Layer] - - -class MockQuantMethod(nn.Module): - - def __init__(self, shared_experts, num_tokens): - super().__init__() - if shared_experts: - self.apply = MagicMock(return_value=(torch.randn(num_tokens, 32), - torch.randn(num_tokens, 10))) - else: - self.apply = MagicMock(return_value=(torch.randn(num_tokens, 32))) - - -class MockFusedMoEMethod(FusedMoEMethodBase): - moe = MagicMock() - - def __init__(self): - super().__init__(self.moe) - - def create_weights(self, layer: torch.nn.Module, num_experts: int, - hidden_size: int, intermediate_size_per_partition: int, - params_dtype: torch.dtype, **extra_weight_attrs): - pass - - def apply(self, hidden_states: torch.Tensor, - expert_weights: torch.Tensor) -> torch.Tensor: - pass - - def get_fused_moe_quant_config(self, layer: torch.nn.Module): - pass - - -class TestTorchairAscendFusedMoe: - - @pytest.fixture - def test_init_no_quant(self, mock_dist_env, default_moe_config): - layer = TorchairAscendFusedMoE(**default_moe_config) - - layer.w13_weight = nn.Parameter( - torch.randn(default_moe_config['num_experts'], - default_moe_config['intermediate_size'] * 2, - default_moe_config['hidden_size'])) - layer.w2_weight = nn.Parameter( - torch.randn(default_moe_config['num_experts'], - default_moe_config['hidden_size'], - default_moe_config['intermediate_size'])) - - assert layer.num_experts == default_moe_config['num_experts'] - assert layer.top_k == default_moe_config['top_k'] - assert hasattr(layer, 'w13_weight') - assert hasattr(layer, 'w2_weight') - - # check group_topk - with pytest.raises(AssertionError): 
- error_config = default_moe_config.copy() - error_config['use_grouped_topk'] = True - layer = TorchairAscendFusedMoE(**error_config) - - # check scoring_func - with pytest.raises(ValueError): - error_config = default_moe_config.copy() - error_config['scoring_func'] = "random" - layer = TorchairAscendFusedMoE(**error_config) - - @pytest.fixture - def test_init_with_quant(self, mock_dist_env, default_moe_config): - mock_quant_config = MagicMock() - mock_quant_method = MockFusedMoEMethod() - mock_quant_config.get_quant_method.return_value = mock_quant_method - mock_quant_config.is_layer_skipped_ascend.return_value = False - with patch("vllm_ascend.quantization.quant_config.get_quant_method"): - moe = TorchairAscendFusedMoE(**default_moe_config, - quant_config=mock_quant_config) - assert moe.quant_method is not None - assert isinstance(moe.quant_method, AscendFusedMoEMethod) - - @pytest.fixture - def test_init_with_mixed_quant(self, mock_dist_env, default_moe_config): - mock_quant_config = MagicMock() - mock_quant_method = MockFusedMoEMethod() - mock_quant_config.get_quant_method.return_value = mock_quant_method - mock_quant_config.is_layer_skipped_ascend.return_value = True - - moe = TorchairAscendFusedMoE(**default_moe_config, - quant_config=mock_quant_config) - - assert moe.quant_method is not None - assert isinstance(moe.quant_method, - TorchairAscendUnquantizedFusedMoEMethod) - - @pytest.fixture - @pytest.mark.parametrize( - "others_param", - [[None, - MagicMock(return_value=torch.randn(5, 32)), False, 5, None], - [2, None, False, 5, None], [None, None, True, 5, None], - [None, None, False, 1, None], [None, None, True, 5, 1], - [None, None, False, 5, 1]]) - def test_forward(self, mock_dist_env, default_moe_config, others_param): - """ - 1 test has shared_experts - 2 test has top_k - 3 test is_prefill is true - 4 test single num_tokens(decode) - 5 test ep_size is 1 and is_prefill is true - 6 test ep_size is 1 and is_prefill is False - """ - top_k, shared_experts, is_prefill, num_tokens, ep_size = others_param - inputs = torch.randn(num_tokens, 32) - router_logits = torch.randn(num_tokens, 8) - moe = TorchairAscendFusedMoE(**default_moe_config) - - if ep_size == 1: - moe.moe_parallel_config.ep_size = 1 - - moe.quant_method = MockQuantMethod(shared_experts, num_tokens) - forward_context = MagicMock(mc2_mask=torch.zeros(num_tokens, - dtype=torch.bool), - padded_num_tokens=num_tokens) - with patch( - "vllm_ascend.torchair.ops.torchair_fused_moe.get_forward_context", - return_value=forward_context): - output = moe.forward(inputs, - router_logits, - is_prefill=is_prefill, - top_k=top_k, - shared_experts=shared_experts) - - moe.quant_method.apply.assert_called_once() - - if shared_experts: - assert output[0].shape == (num_tokens, 32) - assert output[1].shape == (num_tokens, 10) - else: - assert output.shape == (num_tokens, 32) - - @pytest.fixture - def test_forward_ms_fused_moe_comp(self, mock_dist_env, - default_moe_config): - inputs = torch.randn(5, 32) - router_logits = torch.randn(5, 8) - moe = TorchairAscendFusedMoE(**default_moe_config) - - moe.quant_method = MockQuantMethod(None, 5) - output = moe._forward_ms_fused_moe_comp(inputs, - router_logits, - is_prefill=False, - real_top_k=1) - - moe.quant_method.apply.assert_called_once() - - assert output.shape == (5, 32) - - -class TestTorchairAscendUnquantizedFusedMoEMethod: - - def test_process_weights_after_loading(self, moe_method, mock_dist_env): - layer = MagicMock() - layer.w13_weight.data = torch.randn(16, 32) - layer.w2_weight.data = 
torch.randn(16, 32) - - moe_method.process_weights_after_loading(layer) - - assert isinstance(layer.w13_weight, torch.nn.Parameter) - assert isinstance(layer.w2_weight, torch.nn.Parameter) - assert not layer.w13_weight.requires_grad - assert not layer.w2_weight.requires_grad - - @pytest.mark.parametrize("others_param", - [[256, 4], [128, 1], [128, 1], [128, 4]]) - def test_apply_without_expert_map(self, moe_method, mock_dist_env, - mock_moe_env, others_param): - """ - 1 test is_deepseek_v3_r1=true and use fused_experts_with_all2all - 2 test use_select_experts and fused_experts - 3 test use select_gating_topk_softmax_experts and fused_experts - 4 test use select_experts and fused_experts_with_all2all_buffer - """ - global_num_experts, ep_size = others_param - is_prefill = False - global_redundant_expert_num = vllm_ascend.torchair.ops.torchair_fused_moe.get_ascend_config( - ).init_redundancy_expert - is_deepseek_v3_r1 = global_num_experts - global_redundant_expert_num == 256 - forward_context = MagicMock(fused_moe_state=get_fused_moe_state( - ep_size, is_prefill, is_deepseek_v3_r1)) - with patch( - "vllm_ascend.torchair.ops.torchair_fused_moe.get_forward_context", - return_value=forward_context): - moe_method.ep_size = ep_size - x = torch.randn(8, 2, 2) - router_logits = torch.randn(8, 8) - layer = MagicMock() - layer.w13_weight = torch.randn(8, 16, 1) - layer.w2_weight = torch.randn(16, 8, 1) - result = moe_method.apply(layer=layer, - x=x, - router_logits=router_logits, - top_k=2, - renormalize=True, - global_num_experts=global_num_experts, - is_prefill=is_prefill) - - if ep_size == 1: - assert result.shape == (16, 2) - else: - assert result.shape == x.shape - - @pytest.mark.parametrize("others_param", [16, 1, 4]) - def test_apply_with_expert_map(self, moe_method, mock_dist_env, - mock_moe_env, others_param): - """ - 1 test use_select_experts and use fused_expters_with_mc2 - 2 test use_select_experts and fused_experts_with_all2all_buffer - 3 test use_select_experts and fused_experts_with_all2all - 4 test use_select_experts and fused_experts - """ - ep_size = others_param - is_prefill = False - forward_context = MagicMock( - fused_moe_state=get_fused_moe_state(ep_size, is_prefill, True)) - with patch("vllm_ascend.torchair.ops.torchair_fused_moe.get_forward_context", return_value=forward_context), \ - patch("vllm_ascend.torchair.ops.torchair_fused_moe.get_ascend_device_type", return_value=AscendDeviceType._910_93): - expert_map = torch.tensor([0, 1, 2, -1, -1, -1, -1, -1]) - moe_method.ep_size = ep_size - x = torch.randn(8, 2, 2) - if ep_size == 1: - x = x.view(-1, 2) - router_logits = torch.randn(8, 8) - layer = MagicMock() - layer.w13_weight = torch.randn(8, 16, 1) - layer.w2_weight = torch.randn(16, 8, 1) - result = moe_method.apply(layer=layer, - x=x, - router_logits=router_logits, - top_k=2, - renormalize=True, - global_num_experts=128, - expert_map=expert_map, - is_prefill=is_prefill) - - if ep_size == 16 or ep_size == 1: - assert result.shape == (16, 2) - else: - assert result.shape == x.shape diff --git a/tests/ut/torchair/ops/test_torchair_rotary_embedding.py b/tests/ut/torchair/ops/test_torchair_rotary_embedding.py deleted file mode 100644 index 73a78b77..00000000 --- a/tests/ut/torchair/ops/test_torchair_rotary_embedding.py +++ /dev/null @@ -1,333 +0,0 @@ -import math -from unittest.mock import MagicMock, patch - -import torch - -from tests.ut.base import TestBase -from vllm_ascend.torchair.ops.torchair_rotary_embedding import ( - _set_cos_sin_cache, 
custom_rotary_embedding_enabled, - native_rope_deepseek_forward, rope_forward_oot, rotate_half, - yarn_find_correction_dim, yarn_get_mscale) -from vllm_ascend.utils import AscendDeviceType - - -class TestCustomRotaryEmbeddingEnabled(TestBase): - - def setUp(self): - # Common setup for tests - self.positions = torch.tensor([1, 2, 3]) - self.query = torch.randn(3, 4, dtype=torch.float16) - self.key = torch.randn(3, 4, dtype=torch.float16) - self.head_size = 32 - self.cos_sin_cache = torch.randn(3, 4) - - # Mock self object for rope_forward_oot - self.mock_self = MagicMock() - self.mock_self.head_size = self.head_size - self.mock_self.cos_sin_cache = self.cos_sin_cache - self.mock_self.is_neox_style = True - self.mock_self.forward_native.return_value = (self.query, self.key) - - def test_custom_rotary_embedding_enabled(self): - # Test when all conditions are True - with patch( - 'vllm_ascend.torchair.ops.torchair_rotary_embedding.enable_custom_op', - return_value=True): - result = custom_rotary_embedding_enabled(self.query, True, - self.head_size) - self.assertTrue(result) - - # Test when dtype is not float16 - with patch( - 'vllm_ascend.torchair.ops.torchair_rotary_embedding.enable_custom_op', - return_value=True): - query = self.query.to(torch.float32) - result = custom_rotary_embedding_enabled(query, True, - self.head_size) - self.assertFalse(result) - - # Test when neox_style is False - with patch( - 'vllm_ascend.torchair.ops.torchair_rotary_embedding.enable_custom_op', - return_value=True): - result = custom_rotary_embedding_enabled(self.query, False, - self.head_size) - self.assertFalse(result) - - # Test when head_size is not divisible by 32 - with patch( - 'vllm_ascend.torchair.ops.torchair_rotary_embedding.enable_custom_op', - return_value=True): - result = custom_rotary_embedding_enabled(self.query, True, - self.head_size + 1) - self.assertFalse(result) - - # Test when custom op is disabled - with patch( - 'vllm_ascend.torchair.ops.torchair_rotary_embedding.enable_custom_op', - return_value=False): - result = custom_rotary_embedding_enabled(self.query, True, - self.head_size) - self.assertFalse(result) - - -class TestRopeForwardOot(TestBase): - - def setUp(self): - # Common setup for tests - self.positions = torch.tensor([1, 2, 3]) - self.query = torch.randn(3, 4, dtype=torch.float16) - self.key = torch.randn(3, 4, dtype=torch.float16) - self.head_size = 32 - self.cos_sin_cache = torch.randn(3, 4) - - # Mock self object for rope_forward_oot - self.mock_self = MagicMock() - self.mock_self.head_size = self.head_size - self.mock_self.cos_sin_cache = self.cos_sin_cache - self.mock_self.is_neox_style = True - self.mock_self.forward_native.return_value = (self.query, self.key) - - @patch( - 'vllm_ascend.torchair.ops.torchair_rotary_embedding.get_ascend_config') - def test_rope_forward_oot_torchair_enabled_base(self, - mock_get_ascend_config): - # Setup mock for torchair enabled - mock_config = MagicMock() - mock_config.torchair_graph_config.enabled = True - mock_get_ascend_config.return_value = mock_config - - result_q, result_k = rope_forward_oot(self.mock_self, self.positions, - self.query, self.key) - - self.mock_self.forward_native.assert_called_once_with( - self.positions, self.query, self.key, None) - self.assertTrue(torch.equal(result_q, self.query)) - self.assertTrue(torch.equal(result_k, self.key)) - - @patch('torch.ops._C_ascend') - @patch( - 'vllm_ascend.torchair.ops.torchair_rotary_embedding.get_ascend_config') - @patch('vllm_ascend.utils.get_ascend_device_type', - 
return_value=AscendDeviceType._910_93) - @patch( - 'vllm_ascend.torchair.ops.torchair_rotary_embedding.custom_rotary_embedding_enabled', - return_value=True) - @patch('torch.ops._npu_rotary_embedding') - def test_rope_forward_oot_custom_kernel(self, mock_rotary_embedding, - mock_custom_enabled, - mock_soc_version, - mock_get_ascend_config, mock__c): - mock_config = MagicMock() - mock_config.torchair_graph_config.enabled = False - mock_get_ascend_config.return_value = mock_config - - # Setup mock for custom kernel path - - mock__c.rotary_embedding.return_value = self.query, self.key - - result_q, result_k = rope_forward_oot(self.mock_self, self.positions, - self.query, self.key) - - self.assertEqual(result_q.shape, self.query.shape) - self.assertEqual(result_k.shape, self.key.shape) - - @patch( - 'vllm_ascend.torchair.ops.torchair_rotary_embedding.get_ascend_config') - @patch( - 'vllm_ascend.torchair.ops.torchair_rotary_embedding.custom_rotary_embedding_enabled', - return_value=False) - @patch('torch_npu._npu_rotary_embedding') - def test_rope_forward_oot_contiguous(self, mock_npu_rotary, - mock_custom_enabled, - mock_get_ascend_config): - mock_config = MagicMock() - mock_config.torchair_graph_config.enabled = False - mock_get_ascend_config.return_value = mock_config - - # Test contiguous path when custom is disabled - non_contig_query = self.query.transpose(0, 1) - non_contig_key = self.key.transpose(0, 1) - - result_q, result_k = rope_forward_oot(self.mock_self, self.positions, - non_contig_query, non_contig_key) - - mock_npu_rotary.assert_called_once() - self.assertEqual(result_q.shape, non_contig_query.shape) - self.assertEqual(result_k.shape, non_contig_key.shape) - - @patch( - 'vllm_ascend.torchair.ops.torchair_rotary_embedding.get_ascend_config') - def test_rope_forward_oot_with_offsets(self, mock_get_ascend_config): - mock_config = MagicMock() - mock_config.torchair_graph_config.enabled = False - mock_get_ascend_config.return_value = mock_config - - # Test that NotImplementedError is raised when offsets is provided - offsets = torch.tensor([1, 2, 3]) - with self.assertRaises(NotImplementedError): - rope_forward_oot(self.mock_self, self.positions, self.query, - self.key, offsets) - - @patch( - 'vllm_ascend.torchair.ops.torchair_rotary_embedding.get_ascend_config') - @patch( - 'vllm_ascend.torchair.ops.torchair_rotary_embedding.custom_rotary_embedding_enabled', - return_value=False) - @patch('torch_npu._npu_rotary_embedding') - def test_rope_forward_oot_neox_style_override(self, mock_npu_rotary, - mock_custom_enabled, - mock_get_ascend_config): - mock_config = MagicMock() - mock_config.torchair_graph_config.enabled = False - mock_get_ascend_config.return_value = mock_config - - # Test neox_style override - result_q, result_k = rope_forward_oot(self.mock_self, - self.positions, - self.query, - self.key, - is_neox_style_override=False) - - # Check that neox_style=False was passed to the NPU function - args, kwargs = mock_npu_rotary.call_args - self.assertFalse(args[-1]) - - -class MockRopeModule: - - def __init__(self, max_seq_len=2048, is_neox_style=True): - self.max_seq_len = max_seq_len - self.is_neox_style = is_neox_style - self.cos_cached = None - self.sin_cached = None - self.rotary_dim = 1 - self.base = 1 - self.beta_fast = 32 - self.beta_slow = 1 - self.max_position_embeddings = 4096 - self.mscale = 1.0 - self.scaling_factor = 40 - - def register_buffer(self): - pass - - -class TestSetSinCosCache(TestBase): - - def test_set_cos_sin_cache(self): - module = MockRopeModule() - - 
with patch.object(module, "register_buffer") as mock_register_buffer: - _set_cos_sin_cache(module, - 1024, - device="cpu", - dtype=torch.bfloat16) - - mock_register_buffer.assert_called() - - -class TestNativeRopeDeepseekForward(TestBase): - - @patch( - 'vllm_ascend.torchair.ops.torchair_rotary_embedding.rope_forward_oot') - def test_native_rope_deepseek_forward_base(self, mock_rope_forward_oot): - module = MockRopeModule() - positions = torch.tensor([1, 2, 3]) - query = torch.randn(1, 8, 128) - key = torch.randn(1, 8, 128) - - mock_rope_forward_oot.return_value = (query, key) - - q_pe, k_pe = native_rope_deepseek_forward(module, positions, query, - key) - - assert q_pe.shape == query.shape - assert k_pe.shape == key.shape - - @patch( - 'vllm_ascend.torchair.ops.torchair_rotary_embedding.rope_forward_oot') - def test_native_rope_deepseek_forward_key_reshaping( - self, mock_rope_forward_oot): - module = MockRopeModule() - positions = torch.tensor([1, 2, 3]) - query = torch.randn(1, 8, 128) - key = torch.randn(1, 128) - - mock_rope_forward_oot.return_value = (query, key) - - q_pe, k_pe = native_rope_deepseek_forward(module, positions, query, - key) - - assert q_pe.shape == query.shape - assert k_pe.shape == (1, 128) - - @patch( - 'vllm_ascend.torchair.ops.torchair_rotary_embedding.rope_forward_oot') - def test_native_rope_deepseek_forward_non_neox_style( - self, mock_rope_forward_oot): - module = MockRopeModule(is_neox_style=False) - positions = torch.tensor([1, 2, 3]) - query = torch.randn(1, 8, 128) - key = torch.randn(1, 8, 128) - - mock_rope_forward_oot.return_value = (query, key) - - q_pe, k_pe = native_rope_deepseek_forward(module, positions, query, - key) - - assert q_pe.shape == query.shape - assert k_pe.shape == key.shape - - -class TestRotateHalf(TestBase): - - def test_rotate_half_even_dim(self): - # Test with even dimension - x = torch.tensor([1.0, 2.0, 3.0, 4.0]) - expected = torch.tensor([-3.0, -4.0, 1.0, 2.0]) - result = rotate_half(x) - self.assertTrue(torch.allclose(result, expected)) - - -class TestYarnFindCorrectionDim(TestBase): - - def test_basic_case(self): - # Test with standard values - num_rotations = 100 - dim = 512 - base = 10000 - max_position_embeddings = 2048 - - result = yarn_find_correction_dim(num_rotations, dim, base, - max_position_embeddings) - - # Calculate expected value manually - expected = (dim * torch.log( - torch.tensor(max_position_embeddings) / - (num_rotations * 2 * torch.pi))) / (2 * - torch.log(torch.tensor(base))) - - self.assertTrue(torch.allclose(result, expected)) - - -class TestYarnGetMscale(TestBase): - - def test_scale_less_than_or_equal_1(self): - self.assertEqual(yarn_get_mscale(scale=0.5), 1.0) - self.assertEqual(yarn_get_mscale(scale=1.0), 1.0) - self.assertEqual(yarn_get_mscale(scale=0.999), 1.0) - - def test_scale_greater_than_1(self): - test_cases = [(2.0, 1.0, 1.0 + 0.1 * math.log(2.0)), - (10.0, 1.0, 1.0 + 0.1 * math.log(10.0)), - (5.0, 2.0, 1.0 + 0.2 * math.log(5.0)), - (math.e, 1.0, 1.0 + 0.1)] - - for scale, mscale, expected in test_cases: - result = yarn_get_mscale(scale, mscale) - self.assertAlmostEqual( - result, - expected, - places=6, - msg=f"Failed for scale={scale}, mscale={mscale}") diff --git a/tests/ut/torchair/quantization/test_torchair_w4a8_dynamic.py b/tests/ut/torchair/quantization/test_torchair_w4a8_dynamic.py deleted file mode 100644 index f29cafc6..00000000 --- a/tests/ut/torchair/quantization/test_torchair_w4a8_dynamic.py +++ /dev/null @@ -1,296 +0,0 @@ -from unittest.mock import Mock, patch - -import torch 
- -from tests.ut.base import TestBase -from vllm_ascend.torchair.quantization.torchair_w4a8_dynamic import ( - TorchairAscendW4A8DynamicFusedMoEMethod, - TorchairAscendW4A8DynamicLinearMethod) - - -class TestAscendW4A8DynamicLinearMethod(TestBase): - - @patch('vllm.distributed.get_tensor_model_parallel_world_size') - @patch( - 'vllm_ascend.torchair.quantization.torchair_w4a8_dynamic.get_current_vllm_config' - ) - def setUp(self, mock_get_current_vllm_config, mock_get_tp_world_size): - mock_get_tp_world_size.return_value = 1 - mock_vllm_config = Mock() - mock_vllm_config.quant_config = Mock( - quant_description={"group_size": 256}) - mock_get_current_vllm_config.return_value = mock_vllm_config - self.method = TorchairAscendW4A8DynamicLinearMethod() - self.method.group_size = 8 - - def test_get_weight(self): - weight = self.method.get_weight(8, 32, torch.bfloat16) - self.assertEqual(weight["weight"].dtype, torch.int8) - self.assertEqual(weight["weight"].shape, (32, 8)) - # new quant version weight - self.method.new_quant_version = True - weight = self.method.get_weight(8, 32, torch.bfloat16) - self.assertEqual(weight["weight"].dtype, torch.int8) - self.assertEqual(weight["weight"].shape, (16, 8)) - self.assertEqual(weight["_packed_dim"], 0) - self.assertEqual(weight["_packed_factor"], 2) - - def test_get_pergroup_param(self): - params = self.method.get_pergroup_param(8, 32, torch.bfloat16) - self.assertEqual(params["weight_scale"].dtype, torch.bfloat16) - self.assertEqual(params["weight_scale"].shape, (32, 1)) - self.assertEqual(params["weight_offset"].dtype, torch.bfloat16) - self.assertEqual(params["weight_offset"].shape, (32, 1)) - self.assertEqual(params["weight_scale_second"].dtype, torch.bfloat16) - self.assertEqual(params["weight_scale_second"].shape, (32, 1)) - self.assertEqual(params["weight_offset_second"].dtype, torch.bfloat16) - self.assertEqual(params["weight_offset_second"].shape, (32, 1)) - # new quant version weight - self.method.new_quant_version = True - params = self.method.get_pergroup_param(8, - 32, - torch.bfloat16, - layer_type="column") - self.assertEqual(params["scale_bias"].dtype, torch.float32) - self.assertEqual(params["scale_bias"].shape, (32, 1)) - params = self.method.get_pergroup_param(8, - 32, - torch.bfloat16, - layer_type="row") - self.assertEqual(params["scale_bias"].dtype, torch.float32) - self.assertEqual(params["scale_bias"].shape, (32, 16)) - - @patch('torch_npu.npu_convert_weight_to_int4pack') - @patch('torch.Tensor.npu') - def test_process_weights_after_loading(self, mock_npu, - mock_npu_convert_weight): - mock_npu.side_effect = lambda: torch.zeros( - (1, 32), dtype=torch.float32) - mock_npu_convert_weight.return_value = torch.zeros((32, 4), - dtype=torch.int32) - # old quant version weight - layer = torch.nn.Module() - layer.weight = torch.nn.Parameter(torch.zeros((32, 8), - dtype=torch.int8), - requires_grad=False) - layer.weight_scale = torch.nn.Parameter(torch.ones( - (32, 1), dtype=torch.float32), - requires_grad=False) - layer.weight_offset = torch.nn.Parameter(torch.empty_like( - layer.weight_scale.data), - requires_grad=False) - layer.weight_scale_second = torch.nn.Parameter(torch.ones( - (32, 1), dtype=torch.float32), - requires_grad=False) - layer.weight_offset_second = torch.nn.Parameter(torch.empty_like( - layer.weight_scale_second.data), - requires_grad=False) - self.method.process_weights_after_loading(layer) - self.assertTrue(hasattr(layer, "weight_scale_bias")) - self.assertEqual(layer.weight_scale_bias.data.shape, (32, )) - 
self.assertEqual(layer.weight_scale_bias.data.dtype, torch.float32) - # new quant version weight - self.method.new_quant_version = True - new_layer = torch.nn.Module() - new_layer.weight = torch.nn.Parameter(torch.zeros((16, 8), - dtype=torch.int8), - requires_grad=False) - new_layer.weight_scale = torch.nn.Parameter(torch.ones( - (32, 1), dtype=torch.float32), - requires_grad=False) - new_layer.weight_offset = torch.nn.Parameter(torch.empty_like( - new_layer.weight_scale.data), - requires_grad=False) - new_layer.weight_scale_second = torch.nn.Parameter(torch.ones( - (32, 1), dtype=torch.float32), - requires_grad=False) - new_layer.weight_offset_second = torch.nn.Parameter( - torch.empty_like(new_layer.weight_scale_second.data), - requires_grad=False) - new_layer.scale_bias = torch.nn.Parameter(torch.zeros( - (32, 1), dtype=torch.float32), - requires_grad=False) - self.method.process_weights_after_loading(new_layer) - self.assertEqual(new_layer.scale_bias.data.shape, (32, )) - self.assertTrue(hasattr(new_layer, "weight_scale_second")) - self.assertEqual(new_layer.weight_scale_second.data.shape, (1, 32)) - - -class TestAscendW4A8DynamicFusedMoEMethod(TestBase): - experts = 8 - input_size = 16 - output_size = 56 - group_size = 2 - - @patch( - 'vllm_ascend.torchair.quantization.torchair_w4a8_dynamic.get_current_vllm_config' - ) - @patch( - 'vllm_ascend.torchair.quantization.torchair_w4a8_dynamic.get_ep_group') - @patch( - 'vllm_ascend.torchair.quantization.torchair_w4a8_dynamic.get_ascend_config' - ) - @patch( - 'vllm_ascend.torchair.quantization.torchair_w4a8_dynamic.get_mc2_group' - ) - @patch('torch.distributed.get_rank', return_value=0) - def setUp(self, mock_get_rank, mock_get_mc2_group, mock_get_ascend_config, - mock_get_ep_group, get_current_vllm_config): - mock_ascend_config = Mock() - mock_ascend_config.torchair_graph_config = Mock(enabled=False) - mock_get_ascend_config.return_value = mock_ascend_config - mock_vllm_config = Mock() - mock_vllm_config.quant_config = Mock(quant_description={ - "group_size": self.group_size, - "version": "0.0.0" - }) - mock_vllm_config.parallel_config = Mock(enable_expert_parallel=True) - get_current_vllm_config.return_value = mock_vllm_config - self.quant_method = TorchairAscendW4A8DynamicFusedMoEMethod() - - def test_get_weight(self): - # old quant version w4a8 weight - param_dict = self.quant_method.get_weight(self.experts, - self.input_size, - self.output_size, - torch.bfloat16) - self.assertEqual(param_dict["w13_weight"].dtype, torch.int8) - self.assertEqual(param_dict["w13_weight"].shape, - (self.experts, 2 * self.input_size, self.output_size)) - # new quant version weight - self.quant_method.new_quant_version = True - param_dict = self.quant_method.get_weight(self.experts, - self.input_size, - self.output_size, - torch.bfloat16) - self.assertEqual(param_dict["w13_weight"].dtype, torch.int8) - self.assertEqual(param_dict["w13_weight"].shape, - (self.experts, self.input_size, self.output_size)) - - def test_get_dynamic_quant_param(self): - # old quant version weight - param_dict = self.quant_method.get_dynamic_quant_param( - self.experts, self.input_size, self.output_size, torch.bfloat16) - self.assertEqual(param_dict["w13_weight_scale"].dtype, torch.float32) - self.assertEqual(param_dict["w13_weight_scale"].shape, - (self.experts, 2 * self.input_size, 1)) - self.assertEqual(param_dict["w13_weight_scale_second"].dtype, - torch.float32) - self.assertEqual(param_dict["w13_weight_scale_second"].shape, - (self.experts, 2 * self.input_size, - 
self.output_size // self.group_size)) - self.assertEqual(param_dict["w2_weight_scale"].dtype, torch.float32) - self.assertEqual(param_dict["w2_weight_scale"].shape, - (self.experts, self.output_size, 1)) - self.assertEqual(param_dict["w2_weight_scale_second"].dtype, - torch.float32) - self.assertEqual(param_dict["w2_weight_scale_second"].shape, - (self.experts, self.output_size, - self.input_size // self.group_size)) - # new quant version weight - self.quant_method.new_quant_version = True - param_dict = self.quant_method.get_dynamic_quant_param( - self.experts, self.input_size, self.output_size, torch.bfloat16) - self.assertEqual(param_dict["w2_scale_bias"].dtype, torch.float32) - self.assertEqual( - param_dict["w2_scale_bias"].shape, - (self.experts, self.output_size, 16 // self.quant_method.tp_size)) - # per-channel weight - self.quant_method.is_per_channel_weight = True - param_dict = self.quant_method.get_dynamic_quant_param( - self.experts, self.input_size, self.output_size, torch.bfloat16) - pergroup_param = [ - "w13_weight_scale_second", "w13_weight_offset_second", - "w2_weight_scale_second", "w2_weight_offset_second" - ] - is_contains = any(key in param_dict for key in pergroup_param) - self.assertFalse(is_contains) - - def build_layer(self, - is_new_quant_version=True, - is_per_channel_weight=False): - layer = torch.nn.Module() - if is_new_quant_version: - layer.w13_weight = torch.nn.Parameter(torch.zeros( - (self.experts, self.input_size, self.output_size), - dtype=torch.int8), - requires_grad=False) - layer.w2_weight = torch.nn.Parameter(torch.zeros( - (self.experts, self.output_size // 2, self.input_size), - dtype=torch.int8), - requires_grad=False) - w13_scale_bias = torch.zeros( - (self.experts, 2 * self.input_size, 1), dtype=torch.float32) - layer.w13_scale_bias = torch.nn.Parameter(w13_scale_bias, - requires_grad=False) - w2_scale_bias = torch.zeros((self.experts, self.output_size, - 16 // self.quant_method.tp_size), - dtype=torch.float32) - layer.w2_scale_bias = torch.nn.Parameter(w2_scale_bias, - requires_grad=False) - else: - layer.w13_weight = torch.nn.Parameter(torch.zeros( - (self.experts, 2 * self.input_size, self.output_size), - dtype=torch.int8), - requires_grad=False) - layer.w2_weight = torch.nn.Parameter(torch.zeros( - (self.experts, self.output_size, self.input_size), - dtype=torch.int8), - requires_grad=False) - layer.w13_weight_scale = torch.nn.Parameter(torch.ones( - (self.experts, 2 * self.input_size, 1), dtype=torch.float32), - requires_grad=False) - layer.w2_weight_scale = torch.nn.Parameter(torch.ones( - (self.experts, self.output_size, 1), dtype=torch.float32), - requires_grad=False) - if not is_per_channel_weight: - layer.w13_weight_scale_second = torch.nn.Parameter( - torch.ones((self.experts, 2 * self.input_size, - self.output_size // self.group_size), - dtype=torch.float32), - requires_grad=False) - layer.w13_weight_offset_second = torch.nn.Parameter( - torch.empty_like(layer.w13_weight_scale_second.data), - requires_grad=False) - layer.w2_weight_scale_second = torch.nn.Parameter( - torch.ones((self.experts, self.output_size, - self.input_size // self.group_size), - dtype=torch.float32), - requires_grad=False) - layer.w2_weight_offset_second = torch.nn.Parameter( - torch.empty_like(layer.w2_weight_scale_second.data), - requires_grad=False) - return layer - - @patch('torch_npu.npu_quantize') - @patch('torch.Tensor.npu') - def test_process_weights_after_loading(self, mock_npu, mock_npu_quantize): - mock_npu.return_value = torch.Tensor() - 
mock_npu_quantize.return_value = torch.Tensor() - # old quant version weight - layer = self.build_layer(is_new_quant_version=False) - self.quant_method.process_weights_after_loading(layer) - self.assertTrue(hasattr(layer, "w13_scale_bias")) - self.assertEqual(layer.w13_scale_bias.data.shape, - (self.experts, 2 * self.input_size)) - self.assertEqual(layer.w13_scale_bias.data.dtype, torch.float32) - self.assertTrue(hasattr(layer, "w2_scale_bias")) - self.assertEqual(layer.w2_scale_bias.data.shape, - (self.experts, self.output_size)) - self.assertEqual(layer.w2_scale_bias.data.dtype, torch.float32) - # new quant version weight - self.quant_method.new_quant_version = True - new_layer = self.build_layer(is_new_quant_version=True) - self.quant_method.process_weights_after_loading(new_layer) - self.assertEqual(new_layer.w13_scale_bias.data.shape, - (self.experts, 2 * self.input_size)) - self.assertEqual(new_layer.w2_scale_bias.data.shape, - (self.experts, self.output_size)) - self.assertFalse(hasattr(new_layer, "w13_weight_scale_second")) - # per-channel weight - self.quant_method.is_per_channel_weight = True - per_channel_layer = self.build_layer(is_new_quant_version=True, - is_per_channel_weight=True) - self.quant_method.process_weights_after_loading(per_channel_layer) - self.assertEqual(new_layer.w13_scale_bias.data.shape, - (self.experts, 2 * self.input_size)) diff --git a/tests/ut/torchair/quantization/test_torchair_w8a8_dynamic.py b/tests/ut/torchair/quantization/test_torchair_w8a8_dynamic.py deleted file mode 100644 index 11ad00a2..00000000 --- a/tests/ut/torchair/quantization/test_torchair_w8a8_dynamic.py +++ /dev/null @@ -1,129 +0,0 @@ -from unittest.mock import MagicMock, patch - -import torch - -from tests.ut.base import TestBase -from vllm_ascend.torchair.quantization.torchair_w8a8_dynamic import ( - torchair_fused_experts_with_all2all, torchair_fused_experts_with_mc2) -from vllm_ascend.utils import AscendDeviceType - - -class TestAscendW8A8FusedMoEMethod(TestBase): - - def setUp(self): - self.hidden_size = 128 - self.num_tokens = 128 - self.placeholder = torch.randn(self.num_tokens, - self.hidden_size, - dtype=torch.bfloat16) - - @patch("torch.distributed.all_to_all_single") - @patch("torch_npu.npu_moe_re_routing") - @patch("torch_npu.npu_grouped_matmul") - @patch("torch_npu.npu_swiglu") - @patch("torch_npu.npu_dynamic_quant") - @patch("torch_npu.npu_moe_finalize_routing") - @patch("torch_npu.npu_moe_init_routing_quant") - def test_torchair_fused_experts_with_all2all( - self, mock_npu_moe_init_routing_quant, mock_moe_finalize_routing, - mock_dynamic_quant, mock_swiglu, mock_grouped_matmul, - mock_moe_re_routing, mock_all_to_all_single): - - expert_map = MagicMock() - ep_group = MagicMock() - placeholder_int8 = torch.randint(0, - 100, - (self.num_tokens, self.hidden_size), - dtype=torch.int8) - placeholder_ones = torch.ones(self.num_tokens, dtype=torch.int32) - mock_all_to_all_single.side_effect = lambda output, input, *args, **kwargs: output.copy_( - input) - mock_npu_moe_init_routing_quant.return_value = ( - placeholder_int8, placeholder_ones, placeholder_ones, - torch.bincount(placeholder_ones, minlength=len(expert_map)), - torch.randn(self.num_tokens)) - mock_moe_re_routing.return_value = (placeholder_int8, self.placeholder, - torch.randint(0, - 100, - (self.num_tokens, ), - dtype=torch.int32), - self.placeholder) - mock_grouped_matmul.return_value = self.placeholder - mock_swiglu.return_value = self.placeholder - mock_dynamic_quant.return_value = ( - placeholder_int8, - 
torch.randn(self.num_tokens), - ) - mock_moe_finalize_routing.return_value = self.placeholder - - result = torchair_fused_experts_with_all2all( - hidden_states=self.placeholder, - w1=self.placeholder, - w1_scale=self.placeholder, - w2=self.placeholder, - w2_scale=self.placeholder, - topk_weights=self.placeholder, - topk_ids=self.placeholder, - top_k=8, - expert_map=expert_map, - ep_group=ep_group, - log2phy=None, - global_redundant_expert_num=256, - ) - self.assertIsNotNone(result) - self.assertEqual(result.dtype, torch.bfloat16) - self.assertEqual(result.shape, (128, 128)) - - @patch.dict('os.environ', { - 'HCCL_INTRA_ROCE_ENABLE': '0', - 'HCCL_INTRA_PCIE_ENABLE': '1' - }) - @patch( - "vllm_ascend.torchair.quantization.torchair_w8a8_dynamic.get_ascend_device_type" - ) - @patch( - 'vllm_ascend.torchair.quantization.torchair_w8a8_dynamic.get_mc2_group' - ) - @patch('torch_npu.npu_moe_distribute_combine_v2') - @patch('torch_npu.npu_moe_distribute_dispatch_v2') - @patch( - 'vllm_ascend.torchair.quantization.torchair_w8a8_dynamic.torchair_apply_mlp_decode' - ) - def test_torchair_fused_experts_with_mc2_a2_optimization( - self, mock_mlp_decode, mock_dispatch, mock_combine, mock_get_group, - mock_ascend_soc_version): - """Test expert_scales is passed in A2 SOC version with mc2 optimization""" - # Setup mocks - mock_ascend_soc_version.return_value = AscendDeviceType._910B - - mock_group = MagicMock() - mock_group.rank_in_group = 0 - mock_group.world_size = 4 - mock_get_group.return_value = mock_group - - mock_combine.return_value = self.placeholder - - mock_dispatch.return_value = (torch.randn(32, 1024), torch.randn(1), - torch.randint(0, 32, (32, )), - torch.randint(1, 5, (8, )), - torch.randint(1, 5, (4, )), None, - torch.randn(32)) - mock_mlp_decode.return_value = self.placeholder - - result = torchair_fused_experts_with_mc2( - hidden_states=self.placeholder, - w1=self.placeholder, - w2=self.placeholder, - w1_scale=self.placeholder, - w2_scale=self.placeholder, - topk_weights=self.placeholder, - topk_ids=self.placeholder, - top_k=2, - mc2_mask=self.placeholder) - - # Check that expert_scales was passed to dispatch - call_args = mock_dispatch.call_args[1] - self.assertIn('expert_scales', call_args) - - self.assertIsInstance(result, torch.Tensor) - self.assertEqual(result.shape, self.placeholder.shape) diff --git a/tests/ut/torchair/test_torchair_attention.py b/tests/ut/torchair/test_torchair_attention.py deleted file mode 100644 index 0ee79d26..00000000 --- a/tests/ut/torchair/test_torchair_attention.py +++ /dev/null @@ -1,95 +0,0 @@ -from unittest.mock import MagicMock, patch - -import torch -from vllm.attention.backends.abstract import AttentionType -from vllm.distributed.parallel_state import GroupCoordinator - -from tests.ut.base import TestBase -from vllm_ascend.attention.attention_v1 import AscendAttentionState -from vllm_ascend.torchair.torchair_attention import \ - AscendAttentionTorchairBackendImpl - - -class TestAscendAttentionTorchairBackendImpl(TestBase): - - @patch("torch.zeros") - @patch('vllm.distributed.parallel_state._TP', - new_callable=lambda: MagicMock(spec=GroupCoordinator)) # TODO - @patch("vllm.distributed.get_tensor_model_parallel_world_size", - return_value=2) # TODO - @patch("vllm.config.get_current_vllm_config") # TODO - @patch("vllm_ascend.torchair.torchair_mla.get_ascend_config") # TODO - def setUp(self, ascend_config, vllm_config, mock_get_tp_size, mock_tp, - mock_zeros): - mock_tp.world_size = 2 # TODO - ascend_config.torchair_graph_config.enabled = True # TODO - 
ascend_config.torchair_graph_config.enable_kv_nz = False # TODO - speculative_config = MagicMock() - speculative_config.num_speculative_tokens = 4 - vllm_config.speculative_config = speculative_config - - num_heads = 32 - head_size = 128 # TODO - scale = 0.1 # TODO - num_kv_heads = 4 - kv_cache_dtype = "auto" - attn_type = AttentionType.DECODER - mock_zeros.return_value = torch.ones((), - device='cpu', - dtype=torch.int32) - - self.impl = AscendAttentionTorchairBackendImpl( - num_heads=num_heads, - head_size=head_size, - scale=scale, - num_kv_heads=num_kv_heads, - alibi_slopes=None, - sliding_window=None, - kv_cache_dtype=kv_cache_dtype, - blocksparse_params=None, - logits_soft_cap=None, - attn_type=attn_type, - kv_sharing_target_layer_name=None) - - @patch("torch_npu.npu_scatter_nd_update_") - @patch("torch_npu.npu_fused_infer_attention_score") - def test_forward_with_decode_only(self, mock_fused, _): - layer = MagicMock() - layer._k_scale_float = 1.0 - layer._v_scale_float = 1.0 - - seq_len = 1 - num_tokens = 100 - num_blocks = 256 - block_size = 4 - - query = torch.randn(num_tokens, seq_len, - self.impl.num_heads * self.impl.head_size) - key = torch.randn(num_tokens, seq_len, - self.impl.num_kv_heads * self.impl.head_size) - value = torch.randn(num_tokens, seq_len, - self.impl.num_kv_heads * self.impl.head_size) - kv_cache = (torch.randn(num_blocks, block_size, - self.impl.num_heads * self.impl.head_size), - torch.randn(num_blocks, block_size, - self.impl.num_heads * self.impl.head_size)) - output = torch.randn(num_tokens, self.impl.num_heads, - self.impl.head_size) - - decode = MagicMock() # TODO - decode.seq_lens_list = [2] * num_tokens - decode.block_table = torch.ones(num_tokens, 8, dtype=torch.int32) - decode.attn_mask = None - - metadata = MagicMock() - metadata.attn_state = AscendAttentionState.DecodeOnly - metadata.slot_mapping = torch.arange(num_tokens, dtype=torch.int32) - metadata.decode = decode - - mock_fused.return_value = (torch.ones(num_tokens, self.impl.num_heads, - self.impl.head_size), - torch.ones(1)) - - result = self.impl.forward(layer, query, key, value, kv_cache, - metadata, output) - self.assertEqual(result.shape[0], num_tokens) diff --git a/tests/ut/torchair/test_torchair_mla.py b/tests/ut/torchair/test_torchair_mla.py deleted file mode 100644 index 52a4af37..00000000 --- a/tests/ut/torchair/test_torchair_mla.py +++ /dev/null @@ -1,887 +0,0 @@ -from unittest.mock import MagicMock, patch - -import pytest -import torch -from torch import nn -from vllm.distributed.parallel_state import GroupCoordinator -from vllm.model_executor.layers.linear import LinearBase - -from tests.ut.base import TestBase -from vllm_ascend.attention.attention_v1 import AscendAttentionState -from vllm_ascend.attention.utils import AscendCommonAttentionMetadata -from vllm_ascend.torchair.torchair_mla import ( - AscendMLATorchairBackend, AscendMLATorchairDecodeMetadata, - AscendMLATorchairImpl, AscendMLATorchairMetadata, - AscendMLATorchairMetadataBuilder, AscendMLATorchairPrefillMetadata) -from vllm_ascend.torchair.utils import TorchairCommonAttentionMetadata - - -class TestAscendMLATorchairBackend(TestBase): - - def test_get_name(self): - self.assertEqual(AscendMLATorchairBackend.get_name(), - "ASCEND_MLA_TORCHAIR") - - def test_get_builder_cls(self): - self.assertEqual(AscendMLATorchairBackend.get_builder_cls(), - AscendMLATorchairMetadataBuilder) - - def test_get_kv_cache_shape(self): - result = AscendMLATorchairBackend.get_kv_cache_shape(2, 4, 8, 128) - self.assertEqual(result, (2, 4, 8, 
128)) - - def test_get_impl_cls(self): - result = AscendMLATorchairBackend.get_impl_cls() - self.assertEqual(result, AscendMLATorchairImpl) - - -class TestAscendMLATorchairPrefillMetadata(TestBase): - - def test_ascend_mla_prefill_metadata_default(self): - attn_mask = torch.tensor([[1, 0], [1, 1]], dtype=torch.bool) - query_lens = [1, 2] - seq_lens = [2, 2] - context_lens = torch.tensor([1, 2]) - input_positions = torch.tensor([0, 1, 0, 1]) - query_start_loc = torch.tensor([0, 1, 3]) - block_table = torch.tensor([[0, 1], [2, 3]]) - max_query_len = 2 - max_seq_lens = 2 - - metadata = AscendMLATorchairPrefillMetadata( - attn_mask=attn_mask, - query_lens=query_lens, - seq_lens=seq_lens, - context_lens=context_lens, - input_positions=input_positions, - query_start_loc=query_start_loc, - block_table=block_table, - max_query_len=max_query_len, - max_seq_lens=max_seq_lens) - self.assertIs(metadata.attn_mask, attn_mask) - self.assertEqual(metadata.query_lens, query_lens) - self.assertEqual(metadata.seq_lens, seq_lens) - self.assertIs(metadata.context_lens, context_lens) - self.assertIs(metadata.input_positions, input_positions) - self.assertIs(metadata.query_start_loc, query_start_loc) - self.assertIs(metadata.block_table, block_table) - self.assertEqual(metadata.max_query_len, max_query_len) - self.assertEqual(metadata.max_seq_lens, max_seq_lens) - self.assertIsNone(metadata.chunked_context) - - def test_ascend_mla_prefill_metadata_with_chunked_context(self): - cu_seq_lens = torch.tensor([0, 2, 4]) - starts = torch.tensor([0, 2]) - seq_tot = [2, 2] - max_seq_lens = [2, 2] - workspace = torch.randn(2, 4) - chunk_seq_lens = torch.tensor([2, 2]) - - chunked_context = AscendMLATorchairPrefillMetadata.TorchairChunkedContextMetadata( - cu_seq_lens=cu_seq_lens, - starts=starts, - seq_tot=seq_tot, - max_seq_lens=max_seq_lens, - workspace=workspace, - chunk_seq_lens=chunk_seq_lens, - chunk_seq_lens_npu=chunk_seq_lens) - - metadata = AscendMLATorchairPrefillMetadata( - attn_mask=torch.tensor([[1, 0], [1, 1]], dtype=torch.bool), - query_lens=[1, 2], - seq_lens=[2, 2], - context_lens=torch.tensor([1, 2]), - input_positions=torch.tensor([0, 1, 0, 1]), - query_start_loc=torch.tensor([0, 1, 3]), - block_table=torch.tensor([[0, 1], [2, 3]]), - max_query_len=2, - max_seq_lens=2, - chunked_context=chunked_context) - - self.assertIsNotNone(metadata.chunked_context) - self.assertIs(metadata.chunked_context.cu_seq_lens, cu_seq_lens) - self.assertIs(metadata.chunked_context.starts, starts) - self.assertEqual(metadata.chunked_context.seq_tot, seq_tot) - self.assertEqual(metadata.chunked_context.max_seq_lens, max_seq_lens) - self.assertIs(metadata.chunked_context.workspace, workspace) - self.assertIs(metadata.chunked_context.chunk_seq_lens, chunk_seq_lens) - self.assertIs(metadata.chunked_context.chunk_seq_lens_npu, - chunk_seq_lens) - - -class TestAscendMLATorchairDecodeMetadata(TestBase): - - def test_ascend_mla_decode_metadata_default(self): - input_positions = torch.tensor([[1, 2, 3, 4], [1, 2, 3, 4]]) - block_table = torch.tensor([[0, 3, 2, 1], [0, 2, 1, 3]]) - seq_lens = torch.tensor([[2], [3]]) - max_seq_lens = 4 - seq_lens_list = [2, 3] - attn_mask = None - - metadata = AscendMLATorchairDecodeMetadata(input_positions, - block_table, seq_lens, - max_seq_lens, seq_lens_list, - attn_mask) - - self.assertIs(metadata.input_positions, input_positions) - self.assertIs(metadata.block_table, block_table) - self.assertIs(metadata.seq_lens, seq_lens) - self.assertEqual(metadata.max_seq_lens, max_seq_lens) - 
self.assertEqual(metadata.seq_lens_list, seq_lens_list) - self.assertIsNone(attn_mask) - - -class TestAscendMLATorchairMetadata(TestBase): - - def test_ascend_mla_metadata_default(self): - num_actual_tokens = 100 - slot_mapping = torch.randn(100, 4, 1024) - query_start_loc = torch.tensor([1, 2, 3, 4]) - seq_lens = [30, 50] - block_tables = torch.randint(0, 100, (100, 4)) - - num_decodes = 4 - num_decode_tokens = 8 - num_prefills = 8 - - num_input_tokens = 2 - - query_lens = None - head_dim = None - attn_mask = None - attn_state = AscendAttentionState.ChunkedPrefill - - decode = None - prefill = None - - metadata = AscendMLATorchairMetadata( - num_actual_tokens, slot_mapping, query_start_loc, seq_lens, - block_tables, num_decodes, num_decode_tokens, num_prefills, - num_input_tokens, query_lens, head_dim, attn_mask, attn_state, - decode, prefill) - - self.assertEqual(metadata.num_actual_tokens, num_actual_tokens) - self.assertIs(metadata.slot_mapping, slot_mapping) - self.assertIs(metadata.query_start_loc, query_start_loc) - self.assertEqual(metadata.seq_lens, seq_lens) - self.assertIs(metadata.block_tables, block_tables) - self.assertEqual(metadata.num_decodes, num_decodes) - self.assertEqual(metadata.num_decode_tokens, num_decode_tokens) - self.assertEqual(metadata.num_prefills, num_prefills) - self.assertEqual(metadata.num_input_tokens, num_input_tokens) - self.assertEqual(metadata.query_lens, query_lens) - self.assertEqual(metadata.head_dim, head_dim) - self.assertEqual(metadata.attn_mask, attn_mask) - self.assertEqual(metadata.attn_state, attn_state) - self.assertEqual(metadata.decode, decode) - self.assertEqual(metadata.prefill, prefill) - - -class TestAscendMLATorchairMetadataBuilder(TestBase): - - def test_ascend_mla_metadata_builder_default(self): - mock_model_config = MagicMock() - mock_model_config.max_model_len = 1024 - mock_model_config.get_head_size.return_value = 64 - mock_model_config.dtype = torch.float16 - - mock_vllm_config = MagicMock() - mock_vllm_config.model_config = mock_model_config - mock_vllm_config.cache_config = MagicMock(block_size=16) - mock_vllm_config.scheduler_config = MagicMock( - max_num_seqs=4, enable_chunked_prefill=False) - mock_vllm_config.speculative_config = None - - mock_device = torch.device('cpu') - ascend_config = MagicMock() - ascend_config.torchair_graph_config = MagicMock() - ascend_config.torchair_graph_config.enabled = True - with patch("vllm_ascend.torchair.torchair_mla.get_ascend_config", - return_value=ascend_config): - builder = AscendMLATorchairMetadataBuilder(None, None, - mock_vllm_config, - mock_device) - - self.assertEqual(builder.block_size, - mock_vllm_config.cache_config.block_size) - self.assertEqual( - builder.chunked_prefill_enabled, - mock_vllm_config.scheduler_config.enable_chunked_prefill) - self.assertEqual(builder.torchair_graph_enabled, True) - - @patch("vllm_ascend.torchair.torchair_mla.get_ascend_config") - def test_reorder_batch_with_torchair_graph(self, ascend_config): - mock_model_config = MagicMock() - mock_model_config.max_model_len = 1024 - mock_model_config.get_head_size.return_value = 64 - mock_model_config.dtype = torch.float16 - - mock_vllm_config = MagicMock() - mock_vllm_config.model_config = mock_model_config - mock_vllm_config.cache_config = MagicMock(block_size=16) - mock_vllm_config.scheduler_config = MagicMock( - max_num_seqs=4, enable_chunked_prefill=False) - mock_vllm_config.speculative_config = None - - mock_device = torch.device('cpu') - - builder = AscendMLATorchairMetadataBuilder(None, None, - 
mock_vllm_config, - mock_device) - - input_batch = MagicMock() - input_batch.req_ids = [0, 1, 2, 3] - - scheduler_output = MagicMock() - scheduler_output.num_scheduled_tokens = {0: 2, 1: 1, 2: 3, 3: 1} - scheduler_output.scheduled_spec_decode_tokens = { - 0: [1], - 1: [], - 2: [1, 1], - 3: [] - } - - input_batch.swap_states = MagicMock() - - modified = builder.reorder_batch(input_batch, scheduler_output) - - self.assertFalse(modified) - input_batch.swap_states.assert_not_called() - - def test_reorder_batch_without_torchair_graph(self): - ascend_config = MagicMock() - ascend_config.torchair_graph_config = MagicMock() - ascend_config.torchair_graph_config.enabled = False - - mock_model_config = MagicMock() - mock_model_config.max_model_len = 1024 - mock_model_config.get_head_size.return_value = 64 - mock_model_config.dtype = torch.float16 - - mock_vllm_config = MagicMock() - mock_vllm_config.model_config = mock_model_config - mock_vllm_config.cache_config = MagicMock(block_size=16) - mock_vllm_config.scheduler_config = MagicMock( - max_num_seqs=4, enable_chunked_prefill=False) - mock_vllm_config.speculative_config = None - - mock_device = torch.device('cpu') - - with patch("vllm_ascend.torchair.torchair_mla.get_ascend_config", - return_value=ascend_config): - builder = AscendMLATorchairMetadataBuilder(None, None, - mock_vllm_config, - mock_device) - - input_batch = MagicMock() - input_batch.req_ids = [0, 1, 2, 3] - - scheduler_output = MagicMock() - scheduler_output.num_scheduled_tokens = {0: 1, 1: 3, 2: 1, 3: 2} - scheduler_output.scheduled_spec_decode_tokens = { - 0: [], - 1: [1], - 2: [], - 3: [] - } - - input_batch.swap_states = MagicMock() - - modified = builder.reorder_batch(input_batch, scheduler_output) - - self.assertTrue(modified) - input_batch.swap_states.assert_called_once_with(1, 2) - - @patch("vllm_ascend.torchair.torchair_mla.get_ascend_config") - def test_get_graph_runner_block_tables_normal(self, mock_ascend_config): - ascend_config = MagicMock() - mock_ascend_config.return_value = ascend_config - ascend_config.torchair_graph_config.enabled = False - - mock_model_config = MagicMock() - mock_model_config.max_model_len = 1024 - mock_model_config.get_head_size.return_value = 64 - mock_model_config.dtype = torch.float16 - - mock_vllm_config = MagicMock() - mock_vllm_config.model_config = mock_model_config - mock_vllm_config.cache_config = MagicMock(block_size=16) - mock_vllm_config.scheduler_config = MagicMock( - max_num_seqs=4, enable_chunked_prefill=False) - mock_vllm_config.speculative_config = None - - mock_device = torch.device('cpu') - - builder = AscendMLATorchairMetadataBuilder(None, None, - mock_vllm_config, - mock_device) - block_tables = torch.randint(0, 100, (3, 10), dtype=torch.int32) - - result = builder._get_graph_runner_block_tables(3, block_tables) - self.assertEqual(result.shape[0], 3) - self.assertEqual(result.shape[1], 64) - self.assertTrue(torch.equal(result[:, :10], block_tables)) - - @pytest.mark.skip(reason="Skipping this test temporarily.") - @patch("vllm_ascend.torchair.torchair_mla.get_ascend_config") - def test_get_graph_runner_block_tables_truncated(self, mock_ascend_config): - ascend_config = MagicMock() - mock_ascend_config.return_value = ascend_config - ascend_config.torchair_graph_config.enabled = False - - mock_model_config = MagicMock() - mock_model_config.get_head_size.return_value = 64 - mock_model_config.dtype = torch.float16 - - mock_vllm_config = MagicMock() - mock_vllm_config.model_config = mock_model_config - 
mock_vllm_config.cache_config = MagicMock(block_size=16) - mock_vllm_config.scheduler_config = MagicMock( - enable_chunked_prefill=False) - mock_vllm_config.speculative_config = None - - mock_device = torch.device('cpu') - - builder = AscendMLATorchairMetadataBuilder(None, None, - mock_vllm_config, - mock_device) - block_tables = torch.randint(0, 100, (3, 10), dtype=torch.int32) - - result = builder._get_graph_runner_block_tables(3, block_tables) - self.assertEqual(result.shape[0], 3) - self.assertEqual(result.shape[1], 4) - self.assertTrue(torch.equal(result, block_tables[:, :4])) - - @patch("vllm_ascend.torchair.torchair_mla.get_ascend_config") - def test_get_graph_runner_block_tables_from_numpy(self, - mock_ascend_config): - ascend_config = MagicMock() - mock_ascend_config.return_value = ascend_config - ascend_config.torchair_graph_config.enabled = False - - mock_model_config = MagicMock() - mock_model_config.max_model_len = 1024 - mock_model_config.get_head_size.return_value = 64 - mock_model_config.dtype = torch.float16 - - mock_vllm_config = MagicMock() - mock_vllm_config.model_config = mock_model_config - mock_vllm_config.cache_config = MagicMock(block_size=16) - mock_vllm_config.scheduler_config = MagicMock( - max_num_seqs=4, enable_chunked_prefill=False) - mock_vllm_config.speculative_config = None - - mock_device = torch.device('cpu') - - builder = AscendMLATorchairMetadataBuilder(None, None, - mock_vllm_config, - mock_device) - - block_tables = torch.randint(0, 100, (3, 10), dtype=torch.int32) - - result = builder._get_graph_runner_block_tables(3, block_tables) - - self.assertEqual(result.shape[0], 3) - self.assertEqual(result.shape[1], 64) - self.assertTrue(torch.equal(result[:, :10], block_tables)) - - @patch("vllm_ascend.torchair.torchair_mla.get_ascend_config") - def test_build_dummy(self, mock_ascend_config): - ascend_config = MagicMock() - mock_ascend_config.return_value = ascend_config - ascend_config.torchair_graph_config.enabled = False - - mock_model_config = MagicMock() - mock_model_config.max_model_len = 1024 - mock_model_config.get_head_size.return_value = 64 - mock_model_config.dtype = torch.float16 - - mock_vllm_config = MagicMock() - mock_vllm_config.model_config = mock_model_config - mock_vllm_config.cache_config = MagicMock(block_size=16) - mock_vllm_config.scheduler_config = MagicMock( - max_num_seqs=4, enable_chunked_prefill=False) - mock_vllm_config.speculative_config = None - - mock_device = torch.device('cpu') - - builder = AscendMLATorchairMetadataBuilder( - None, - None, - mock_vllm_config, - mock_device, - metadata_cls=AscendMLATorchairMetadata) - builder.rope_dim = 64 - - with patch.object(builder, - "_get_graph_runner_block_tables", - side_effect=lambda x, y: y): - common_attn_metadata = TorchairCommonAttentionMetadata( - num_reqs=3, - num_actual_tokens=3, - decode_token_per_req=1, - actual_seq_lengths_q=[0, 1, 2], - attn_mask=torch.zeros((1, 1), dtype=torch.bool), - spec_attn_mask=torch.zeros((1, 1), dtype=torch.bool), - ) - metadata = builder.build_torchair_graph_dummy(common_attn_metadata) - - sin_golden = torch.ones(3, - 1, - 1, - 64, - dtype=torch.float16, - device=mock_device) - cos_golden = torch.ones(3, - 1, - 1, - 64, - dtype=torch.float16, - device=mock_device) - - self.assertIsInstance(metadata, AscendMLATorchairMetadata) - self.assertEqual(metadata.num_input_tokens, 3) - self.assertEqual(metadata.num_actual_tokens, 3) - self.assertEqual(metadata.num_decodes, 1) - self.assertEqual(metadata.num_decode_tokens, 1) - 
self.assertEqual(metadata.num_prefills, 0) - self.assertEqual(metadata.attn_state, AscendAttentionState.DecodeOnly) - self.assertIsNone(metadata.prefill) - self.assertIsInstance(metadata.decode, AscendMLATorchairDecodeMetadata) - self.assertEqual(metadata.block_tables.shape[0], 3) - self.assertEqual(metadata.block_tables.shape[1], 64) - self.assertEqual(metadata.seq_lens.shape[0], 3) - self.assertEqual(metadata.slot_mapping.shape[0], 3) - self.assertEqual(metadata.query_start_loc.shape[0], 3) - assert torch.equal(sin_golden, metadata.decode.sin) - assert torch.equal(cos_golden, metadata.decode.cos) - - @patch("vllm_ascend.torchair.torchair_mla.get_ascend_config") - def test_build_decode(self, mock_ascend_config): - ascend_config = MagicMock() - mock_ascend_config.return_value = ascend_config - ascend_config.torchair_graph_config.enabled = False - - mock_model_config = MagicMock() - mock_model_config.max_model_len = 1024 - mock_model_config.get_head_size.return_value = 64 - mock_model_config.dtype = torch.float16 - - mock_vllm_config = MagicMock() - mock_vllm_config.model_config = mock_model_config - mock_vllm_config.cache_config = MagicMock(block_size=16) - mock_vllm_config.scheduler_config = MagicMock( - max_num_seqs=4, enable_chunked_prefill=False) - mock_vllm_config.speculative_config = None - - mock_device = torch.device('cpu') - - model = MagicMock(spec=nn.Module) - model.model = MagicMock(spec=nn.Module) - - builder = AscendMLATorchairMetadataBuilder( - None, - None, - mock_vllm_config, - mock_device, - metadata_cls=AscendMLATorchairMetadata) - builder.rope_dim = 64 - - builder.sin_cache = torch.tensor([10, 10]) - builder.cos_cache = torch.tensor([10, 10]) - - with patch.object(builder, - "_get_graph_runner_block_tables", - side_effect=lambda x, y: y): - common_attn_metadata = AscendCommonAttentionMetadata( - query_start_loc=torch.tensor([0, 1, 2, 3]), - query_start_loc_cpu=torch.tensor([0, 1, 2, 3]), - seq_lens_cpu=torch.tensor([1, 1, 1]), - num_reqs=3, - num_actual_tokens=3, - max_query_len=1, - decode_token_per_req=torch.tensor([1, 1, 1]), - block_table_tensor=torch.zeros((10, 10)), - slot_mapping=torch.tensor(range(20)), - actual_seq_lengths_q=torch.tensor([0, 1, 2]), - positions=torch.tensor([1, 1]), - attn_mask=torch.ones((15, 15)), - spec_attn_mask=None, - attn_state=AscendAttentionState.ChunkedPrefill, - num_computed_tokens_cpu=None, - seq_lens=None) - - metadata = builder.build(1, common_attn_metadata, model) - - self.assertIsInstance(metadata, AscendMLATorchairMetadata) - self.assertEqual(metadata.num_input_tokens, 0) - self.assertEqual(metadata.num_actual_tokens, 3) - self.assertEqual(metadata.num_decodes, 3) - self.assertEqual(metadata.num_decode_tokens, 3) - self.assertEqual(metadata.num_prefills, 0) - self.assertEqual(metadata.attn_state, - AscendAttentionState.ChunkedPrefill) - self.assertIsNone(metadata.prefill) - self.assertIsInstance(metadata.decode, AscendMLATorchairDecodeMetadata) - self.assertEqual(metadata.block_tables.shape[0], 3) - self.assertEqual(metadata.block_tables.shape[1], 10) - self.assertEqual(metadata.seq_lens.shape[0], 3) - self.assertEqual(metadata.slot_mapping.shape[0], 3) - self.assertEqual(metadata.query_start_loc.shape[0], 4) - - -class TestAscendMLATorchairImpl(TestBase): - - @patch('vllm.distributed.parallel_state._TP', - new_callable=lambda: MagicMock(spec=GroupCoordinator)) - @patch("vllm.distributed.get_tensor_model_parallel_world_size", - return_value=2) - @patch("vllm.config.get_current_vllm_config") - 
@patch("vllm_ascend.torchair.torchair_mla.get_ascend_config") - def setUp(self, ascend_config, vllm_config, mock_get_tp_size, mock_tp): - mock_tp.world_size = 2 - ascend_config.torchair_graph_config.enabled = True - ascend_config.torchair_graph_config.enable_kv_nz = False - speculative_config = MagicMock() - speculative_config.num_speculative_tokens = 4 - vllm_config.speculative_config = speculative_config - - num_heads = 256 - head_size = 1024 - scale = 0.1 - num_kv_heads = 8 - kv_cache_dtype = "auto" - - kv_a_layernorm = MagicMock() - kv_a_layernorm.weight = torch.randn(96) - kv_a_layernorm.variance_epsilon = 1e-6 - kwargs = { - "q_lora_rank": 64, - "kv_lora_rank": 32, - "qk_nope_head_dim": 64, - "qk_rope_head_dim": 32, - "qk_head_dim": 96, - "v_head_dim": 128, - "rotary_emb": MagicMock(), - "q_proj": MagicMock(), - "q_b_proj": MagicMock(), - "kv_b_proj": MagicMock(), - "o_proj": MagicMock(), - "kv_a_proj_with_mqa": MagicMock(), - "kv_a_layernorm": kv_a_layernorm, - } - - self.impl = AscendMLATorchairImpl(num_heads=num_heads, - head_size=head_size, - scale=scale, - num_kv_heads=num_kv_heads, - alibi_slopes=None, - sliding_window=None, - kv_cache_dtype=kv_cache_dtype, - blocksparse_params=None, - logits_soft_cap=None, - attn_type=None, - kv_sharing_target_layer_name=None, - **kwargs) - - def test_init(self): - self.assertEqual(self.impl.num_heads, 256) - self.assertEqual(self.impl.head_size, 1024) - self.assertEqual(self.impl.scale, 0.1) - self.assertEqual(self.impl.num_kv_heads, 8) - self.assertEqual(self.impl.kv_cache_dtype, "auto") - self.assertEqual(self.impl.q_lora_rank, 64) - self.assertEqual(self.impl.kv_lora_rank, 32) - self.assertEqual(self.impl.qk_nope_head_dim, 64) - self.assertEqual(self.impl.qk_rope_head_dim, 32) - self.assertEqual(self.impl.qk_head_dim, 96) - self.assertEqual(self.impl.v_head_dim, 128) - self.assertIsNotNone(self.impl.rotary_emb) - self.assertIsNotNone(self.impl.q_proj) - self.assertIsNotNone(self.impl.kv_b_proj) - self.assertIsNotNone(self.impl.o_proj) - self.assertIsNotNone(self.impl.kv_a_proj_with_mqa) - self.assertIsNotNone(self.impl.kv_a_layernorm) - self.assertEqual(self.impl.num_queries_per_kv, 32) - self.assertEqual(self.impl.tp_size, 2) - self.assertTrue(self.impl.torchair_graph_enabled) - - def test_v_up_proj_and_o_proj(self): - batch_size = 4 - x = torch.randn(batch_size, self.impl.num_heads, - self.impl.kv_lora_rank) - - self.impl.o_proj.return_value = (torch.randn( - batch_size, self.impl.num_heads * self.impl.v_head_dim), ) - if not hasattr(self.impl, 'W_UV') or self.impl.W_UV is None: - self.impl.W_UV = torch.randn(self.impl.num_heads, - self.impl.kv_lora_rank, - self.impl.v_head_dim) - result = self.impl._v_up_proj_and_o_proj(x) - - self.assertEqual(result.shape[0], batch_size) - self.assertEqual(result.shape[1], - self.impl.num_heads * self.impl.v_head_dim) - - def test_q_proj_and_k_up_proj(self): - batch_size = 4 - x = torch.randn(batch_size, self.impl.num_heads, self.impl.qk_head_dim) - q_proj_output = torch.randn(batch_size, self.impl.num_heads, - self.impl.qk_head_dim) - self.impl.q_proj.return_value = (q_proj_output, ) - if not hasattr(self.impl, 'W_UK_T') or self.impl.W_UK_T is None: - self.impl.W_UK_T = torch.randn(self.impl.num_heads, - self.impl.qk_nope_head_dim, - self.impl.kv_lora_rank) - result = self.impl._q_proj_and_k_up_proj(x) - ql_nope, q_pe = result - self.assertEqual(ql_nope.shape[0], batch_size) - self.assertEqual(ql_nope.shape[1], self.impl.num_heads) - self.assertEqual(ql_nope.shape[2], self.impl.kv_lora_rank) - 
self.assertEqual(q_pe.shape[0], batch_size) - self.assertEqual(q_pe.shape[1], self.impl.num_heads) - self.assertEqual(q_pe.shape[2], self.impl.qk_rope_head_dim) - - def test_process_weights_after_loading(self): - layer = MagicMock(spec=LinearBase) - layer.input_size_per_partition = 10 - quant_method = MagicMock() - apply = MagicMock() - quant_method.apply = apply - layer.quant_method = quant_method - shape_0 = self.impl.num_heads * (self.impl.qk_nope_head_dim + - self.impl.v_head_dim) - shape_1 = self.impl.kv_lora_rank - layer.weight = torch.randn(shape_0, shape_1) - self.impl.kv_b_proj = layer - apply.return_value = layer.weight.T - self.impl.process_weights_after_loading(torch.bfloat16) - - self.assertEqual(self.impl.W_UK_T.shape[0], self.impl.num_heads) - self.assertEqual(self.impl.W_UK_T.shape[1], self.impl.qk_nope_head_dim) - self.assertEqual(self.impl.W_UK_T.shape[2], self.impl.kv_lora_rank) - - self.assertEqual(self.impl.W_UV.shape[0], self.impl.num_heads) - self.assertEqual(self.impl.W_UV.shape[1], self.impl.kv_lora_rank) - self.assertEqual(self.impl.W_UV.shape[2], self.impl.v_head_dim) - - def test_compute_prefill_context_none(self): - batch_size = 4 - kv_cache = torch.randn(10, 1, 1, 192) - query = torch.randn(batch_size, self.impl.num_heads, - self.impl.qk_head_dim) - metadata = MagicMock() - metadata.prefill = None - prefix_out = torch.randn(2, 16, 128) - prefix_lse = torch.randn(2, 16, 8) - out, lse = self.impl._compute_prefill_context(query, kv_cache, 32, - metadata, prefix_out, - prefix_lse) - - self.assertTrue(torch.equal(prefix_out, out)) - self.assertTrue(torch.equal(prefix_lse, lse)) - - @patch("torch_npu.atb.npu_paged_cache_load") - @patch("torch_npu.atb.npu_ring_mla") - def test_compute_prefill_context(self, mock_ring, mock_load): - S, N, D, VD = 2, self.impl.num_heads, self.impl.qk_head_dim, self.impl.v_head_dim - _, AND = self.impl.qk_rope_head_dim, self.impl.qk_nope_head_dim - latent_kv_dim = self.impl.kv_lora_rank - num_blocks, block_size = 100, 20 - query = torch.randn(S, N, D) - kv_cache_0 = torch.randn(num_blocks, block_size, N, latent_kv_dim) - kv_cache_1 = torch.randn(num_blocks, block_size, N, D) - kv_cache = [kv_cache_0, kv_cache_1] - prefix_out = torch.randn(S, N, 128) - prefix_lse = torch.randn(S, N) - - self.impl.kv_b_proj.return_value = (torch.randn(8, N, VD + AND), ) - - chunk_ctx = MagicMock() - chunk_ctx.seq_tot = [8] - chunk_ctx.chunk_seq_lens = [torch.tensor([8])] - chunk_ctx.chunk_seq_lens_npu = [torch.tensor([8])] - chunk_ctx.starts = [torch.tensor([0])] - - prefill_meta = MagicMock() - prefill_meta.chunked_context = chunk_ctx - prefill_meta.query_lens = [8] - prefill_meta.block_table = torch.randint(0, 100, (S, 4)) - - meta = MagicMock() - meta.prefill = prefill_meta - - out, lse = self.impl._compute_prefill_context(query, kv_cache, 32, - meta, prefix_out, - prefix_lse) - - mock_load.assert_called_once() - mock_ring.assert_called_once() - - self.assertEqual(out.shape, prefix_out.shape) - self.assertEqual(lse.shape, prefix_lse.shape) - - @patch("torch_npu.npu_kv_rmsnorm_rope_cache") - def test_exec_kv(self, mock_kv_cache): - batch_size = 2 - hidden = torch.randn(batch_size, 128) - cos = torch.randn(batch_size, 32) - sin = torch.randn(batch_size, 32) - kv_cache = (torch.randn( - 4, 8, self.impl.kv_lora_rank + self.impl.qk_rope_head_dim), - torch.randn( - 4, 8, - self.impl.kv_lora_rank + self.impl.qk_rope_head_dim)) - slots = torch.arange(batch_size, dtype=torch.long) - - proj_out = torch.randn( - batch_size, self.impl.num_kv_heads, 1, - 
self.impl.kv_lora_rank + self.impl.qk_rope_head_dim) - self.impl.kv_a_proj_with_mqa.return_value = (proj_out, ) - - mock_kv_cache.return_value = (torch.randn(batch_size, - self.impl.num_kv_heads, 1, - self.impl.qk_rope_head_dim), - torch.randn(batch_size, - self.impl.num_kv_heads, 1, - self.impl.kv_lora_rank), - None, None) - - k_pe, k_nope, kv = self.impl.exec_kv(hidden, cos, sin, kv_cache, slots) - - self.impl.kv_a_proj_with_mqa.assert_called_once_with(hidden) - mock_kv_cache.assert_called_once() - self.assertEqual(k_pe.shape, (batch_size, self.impl.num_kv_heads, 1, - self.impl.qk_rope_head_dim)) - self.assertEqual( - k_nope.shape, - (batch_size, self.impl.num_kv_heads, 1, self.impl.kv_lora_rank)) - self.assertEqual(kv.shape, - (batch_size, self.impl.num_kv_heads, 1, - self.impl.kv_lora_rank + self.impl.qk_rope_head_dim)) - - @patch("torch_npu.npu_kv_rmsnorm_rope_cache") - def test_exec_kv_prefill(self, mock_kv): - B, N, S, H = 2, self.impl.num_kv_heads, 1, 128 - hidden_states = torch.randn(B, N, S, H) - cos = torch.randn(B, S, 32) - sin = torch.randn(B, S, 32) - kv_cache = ( - torch.randn(100, 8, - self.impl.kv_lora_rank + self.impl.qk_rope_head_dim), - torch.randn(100, 8, - self.impl.kv_lora_rank + self.impl.qk_rope_head_dim), - ) - - slots = torch.arange(B * S, dtype=torch.long) - - proj_out = torch.randn( - B, N, S, self.impl.kv_lora_rank + self.impl.qk_rope_head_dim) - self.impl.kv_a_proj_with_mqa.return_value = (proj_out, ) - - mock_kv.return_value = (None, None, - torch.randn(B, self.impl.num_kv_heads, S, - self.impl.qk_rope_head_dim), - torch.randn(B, self.impl.num_kv_heads, S, - self.impl.kv_lora_rank)) - - k_pe, k_nope = self.impl.exec_kv_prefill(hidden_states, cos, sin, - kv_cache, slots) - - self.impl.kv_a_proj_with_mqa.assert_called_once_with(hidden_states) - mock_kv.assert_called_once() - - self.assertEqual( - k_pe.shape, - (B, self.impl.num_kv_heads, S, self.impl.qk_rope_head_dim)) - self.assertEqual( - k_nope.shape, - (B, self.impl.num_kv_heads, S, self.impl.kv_lora_rank)) - - @patch("torch_npu.npu_interleave_rope") - def test_rope_single(self, mock_rope): - B, N, D = 2, 16, 1024 - x = torch.randn(B, N, D) - cos = torch.randn(B, N, 1, D) - sin = torch.randn(B, N, 1, D) - mock_rope.return_value = x.view(B, N, 1, D) - result = self.impl.rope_single(x, cos, sin) - self.assertEqual(result.shape[0], B) - self.assertEqual(result.shape[1], N) - self.assertEqual(result.shape[2], D) - mock_rope.assert_called_once() - - @patch( - "vllm_ascend.torchair.torchair_mla.AscendMLATorchairImpl._v_up_proj_and_o_proj" - ) - @patch("torch_npu._npu_paged_attention_mla") - def test_forward_decode_without_graph(self, mock_page_attention_mla, - mock_up_proj): - self.impl.running_in_graph = False - self.impl.running_chunkprefilll_with_torchair = False - num_tokens = 100 - num_blocks = 256 - block_size = 4 - q_nope = torch.randn(num_tokens, self.impl.num_heads, - self.impl.qk_nope_head_dim) - q_pe = torch.randn(num_tokens, self.impl.num_heads, - self.impl.qk_rope_head_dim) - kv_c_and_k_pe_cache = torch.randn(num_blocks, block_size, - self.impl.num_heads, - self.impl.kv_lora_rank) - metadata = MagicMock() - metadata.decode = MagicMock() - metadata.decode.block_table = MagicMock() - metadata.decode.seq_lens = 10 - mock_page_attention_mla.return_value = torch.randn( - num_tokens, self.impl.num_heads, self.impl.kv_lora_rank) - mock_up_proj.return_value = torch.randn(num_tokens, - self.impl.num_heads, - self.impl.v_head_dim) - result = self.impl._forward_decode(q_nope, q_pe, None, None, - 
kv_c_and_k_pe_cache, metadata) - self.assertEqual(result.shape[0], num_tokens) - self.assertEqual(result.shape[1], self.impl.num_heads) - self.assertEqual(result.shape[2], self.impl.v_head_dim) - mock_up_proj.assert_called_once() - mock_page_attention_mla.assert_called_once() - - @patch( - "vllm_ascend.torchair.torchair_mla.AscendMLATorchairImpl._forward_prefill" - ) - @patch("torch_npu._npu_reshape_and_cache") - def test_forward_without_graph(self, _, mock_forward_prefill): - self.impl.running_in_graph = False - self.impl.torchair_graph_enabled = False - - num_tokens = 100 - num_blocks = 256 - block_size = 4 - rotary_emb_return_value = (torch.randn(num_tokens, 16, - self.impl.kv_lora_rank), - torch.randn(0, 1, self.impl.kv_lora_rank)) - self.impl.rotary_emb.side_effect = lambda *args, **kwargs: rotary_emb_return_value - self.impl.o_proj.side_effect = lambda *args, **kwargs: torch.randn( - 1, num_blocks, 128) - - hidden_states_or_q_c = torch.randn(num_tokens, self.impl.q_lora_rank) - hidden_states_or_kv_c_normed = torch.randn(num_tokens, - self.impl.kv_lora_rank) - k_pe = torch.randn(num_tokens, self.impl.qk_rope_head_dim) - kv_cache = (torch.randn(num_blocks, block_size, self.impl.num_heads, - self.impl.kv_lora_rank), - torch.randn(num_blocks, block_size, self.impl.num_heads, - self.impl.qk_rope_head_dim)) - output = torch.randn(num_tokens, self.impl.num_heads, - self.impl.v_head_dim) - - metadata = MagicMock() - metadata.num_decodes = 0 - metadata.num_prefills = num_tokens - mock_forward_prefill.return_value = torch.randn( - 0, self.impl.num_heads * self.impl.v_head_dim) - result = self.impl.forward(None, hidden_states_or_q_c, - hidden_states_or_kv_c_normed, k_pe, - kv_cache, metadata, output, False) - self.assertEqual(result.shape[0], num_tokens) diff --git a/tests/ut/torchair/test_torchair_model_runner.py b/tests/ut/torchair/test_torchair_model_runner.py deleted file mode 100644 index bbbe82bb..00000000 --- a/tests/ut/torchair/test_torchair_model_runner.py +++ /dev/null @@ -1,45 +0,0 @@ -from unittest.mock import MagicMock, Mock - -import pytest -import torch -from pytest_mock import MockerFixture -from vllm.config import VllmConfig - -from tests.ut.base import PytestBase -from vllm_ascend.torchair.torchair_model_runner import NPUTorchairModelRunner - - -class TestNPUTorchairModelRunner(PytestBase): - - @pytest.fixture - def setup_npu_torchair_model_runner(self, mocker: MockerFixture): - mocker.patch.object(NPUTorchairModelRunner, "__init__", - lambda self, *args, **kwargs: None) - runner = NPUTorchairModelRunner(Mock(), Mock()) - - runner.device = torch.device("cpu") - runner.vllm_config = MagicMock(spec=VllmConfig) - - runner.speculative_config = MagicMock( - method="mtp", - num_speculative_tokens=4, - disable_padded_drafter_batch=False) - - runner.ascend_config = MagicMock(enable_shared_expert_dp=False, - torchair_graph_config=MagicMock( - use_cached_graph=True, - graph_batch_sizes=[1, 2, 4])) - - runner.decode_token_per_req = 2 - runner.is_kv_consumer = True - runner.max_num_reqs = 100 - - runner.model_config = MagicMock(hf_config=MagicMock(index_topk=2)) - runner.attn_backend = MagicMock(get_builder_cls=lambda: Mock()) - - return runner - - def test_init(self, mocker: MockerFixture, - setup_npu_torchair_model_runner): - runner = setup_npu_torchair_model_runner - assert isinstance(runner, NPUTorchairModelRunner) diff --git a/tests/ut/torchair/test_torchair_mtp_proposer.py b/tests/ut/torchair/test_torchair_mtp_proposer.py deleted file mode 100644 index ec2dc425..00000000 --- 
a/tests/ut/torchair/test_torchair_mtp_proposer.py +++ /dev/null @@ -1,78 +0,0 @@ -from unittest.mock import MagicMock, Mock - -import pytest -import torch -from pytest_mock import MockerFixture -from vllm.config import CacheConfig, VllmConfig - -from tests.ut.base import PytestBase -from vllm_ascend.torchair.torchair_mtp_proposer import TorchairMtpProposer - - -class TestTorchairMtpProposer(PytestBase): - - @pytest.fixture - def setup_torchair_mtp_proposer(self, mocker: MockerFixture): - vllm_config = MagicMock(spec=VllmConfig) - vllm_config.device_config = MagicMock() - vllm_config.device_config.device = torch.device("cpu") - vllm_config.speculative_config = MagicMock() - vllm_config.speculative_config.draft_model_config = MagicMock() - vllm_config.speculative_config.draft_model_config.dtype = torch.float16 - vllm_config.speculative_config.method = "mtp" - vllm_config.speculative_config.num_speculative_tokens = 5 - vllm_config.load_config = MagicMock() - cache_config = CacheConfig(block_size=16) - vllm_config.cache_config = cache_config - vllm_config.scheduler_config = MagicMock(max_num_batched_tokens=1024, - max_num_seqs=64) - - device = torch.device("cpu") - runner = MagicMock() - runner.pcp_size = 1 - runner.dcp_size = 1 - runner.pcp_rank = 0 - runner.max_num_tokens = 1024 - runner.max_num_reqs = 10 - runner._use_aclgraph.return_value = True - - mocker.patch( - "vllm_ascend.torchair.torchair_mtp_proposer.MtpProposer.__init__", - return_value=None) - mock_set_default_dtype = mocker.patch( - 'vllm.utils.torch_utils.set_default_torch_dtype') - mock_set_default_dtype.return_value.__enter__.return_value = None - - mock_model_loader = MagicMock() - mocker.patch("vllm.model_executor.model_loader.get_model_loader", - return_value=mock_model_loader) - mock_layers = { - "target_attn_layer_1": Mock(), - "draft_attn_layer_2": Mock() - } - mocker.patch("vllm.config.get_layers_from_vllm_config", - return_value=mock_layers) - mock_set_current = mocker.patch("vllm.config.set_current_vllm_config") - mock_set_current.return_value.__enter__.return_value = None - mock_torchair_deepseek_mtp = MagicMock() - mock_torchair_deepseek_mtp.to.return_value = mock_torchair_deepseek_mtp - mocker.patch( - "vllm_ascend.torchair.models.torchair_deepseek_mtp.TorchairDeepSeekMTP", - return_value=mock_torchair_deepseek_mtp) - mocker.patch( - "vllm.model_executor.model_loader.utils.process_weights_after_loading" - ) - - proposer = TorchairMtpProposer(vllm_config, device, runner) - proposer.vllm_config = vllm_config - proposer.device = device - proposer.runner = runner - proposer.speculative_config = vllm_config.speculative_config - proposer.draft_model_config = vllm_config.speculative_config.draft_model_config - proposer.method = vllm_config.speculative_config.method - - return proposer, mock_model_loader, mock_torchair_deepseek_mtp - - def test_init(self, setup_torchair_mtp_proposer): - proposer, _, _, = setup_torchair_mtp_proposer - assert isinstance(proposer, TorchairMtpProposer) diff --git a/tests/ut/torchair/test_torchair_sfa.py b/tests/ut/torchair/test_torchair_sfa.py deleted file mode 100644 index 4552e877..00000000 --- a/tests/ut/torchair/test_torchair_sfa.py +++ /dev/null @@ -1,340 +0,0 @@ -from unittest.mock import MagicMock, patch - -import torch - -from tests.ut.base import TestBase -from vllm_ascend.attention.attention_v1 import AscendAttentionState -from vllm_ascend.torchair.torchair_sfa import ( - AscendSFATorchairBackend, AscendSFATorchairDecodeMetadata, - AscendSFATorchairImpl, AscendSFATorchairMetadata, - 
AscendSFATorchairMetadataBuilder, AscendSFATorchairPrefillMetadata) - - -class TestAscendSFATorchairBackend(TestBase): - - def test_get_name(self): - self.assertEqual(AscendSFATorchairBackend.get_name(), - "ASCEND_SFA_TORCHAIR") - - def test_get_builder_cls(self): - self.assertEqual(AscendSFATorchairBackend.get_builder_cls(), - AscendSFATorchairMetadataBuilder) - - def test_get_kv_cache_shape(self): - result = AscendSFATorchairBackend.get_kv_cache_shape(2, 4, 8, 128) - self.assertEqual(result, (2, 4, 8, 128)) - - def test_get_impl_cls(self): - result = AscendSFATorchairBackend.get_impl_cls() - self.assertEqual(result, AscendSFATorchairImpl) - - -class TestAscendSFATorchairPrefillMetadata(TestBase): - - def test_ascend_sfa_prefill_metadata_default(self): - attn_mask = torch.tensor([[1, 0], [1, 1]], dtype=torch.bool) - query_lens = [1, 2] - seq_lens = [2, 2] - context_lens = torch.tensor([1, 2]) - input_positions = torch.tensor([0, 1, 0, 1]) - query_start_loc = torch.tensor([0, 1, 3]) - block_table = torch.tensor([[0, 1], [2, 3]]) - max_query_len = 2 - max_seq_lens = 2 - - metadata = AscendSFATorchairPrefillMetadata( - attn_mask=attn_mask, - query_lens=query_lens, - seq_lens=seq_lens, - context_lens=context_lens, - input_positions=input_positions, - query_start_loc=query_start_loc, - block_table=block_table, - max_query_len=max_query_len, - sin=None, - cos=None, - max_seq_lens=max_seq_lens) - self.assertIs(metadata.attn_mask, attn_mask) - self.assertEqual(metadata.query_lens, query_lens) - self.assertEqual(metadata.seq_lens, seq_lens) - self.assertIs(metadata.context_lens, context_lens) - self.assertIs(metadata.input_positions, input_positions) - self.assertIs(metadata.query_start_loc, query_start_loc) - self.assertIs(metadata.block_table, block_table) - self.assertEqual(metadata.max_query_len, max_query_len) - self.assertEqual(metadata.max_seq_lens, max_seq_lens) - self.assertIsNone(metadata.chunked_context) - - def test_ascend_sfa_prefill_metadata_with_chunked_context(self): - cu_seq_lens = torch.tensor([0, 2, 4]) - starts = torch.tensor([0, 2]) - seq_tot = [2, 2] - max_seq_lens = [2, 2] - workspace = torch.randn(2, 4) - chunk_seq_lens = torch.tensor([2, 2]) - - chunked_context = AscendSFATorchairPrefillMetadata.TorchairChunkedContextMetadata( - cu_seq_lens=cu_seq_lens, - starts=starts, - seq_tot=seq_tot, - max_seq_lens=max_seq_lens, - workspace=workspace, - chunk_seq_lens=chunk_seq_lens) - - metadata = AscendSFATorchairPrefillMetadata( - attn_mask=torch.tensor([[1, 0], [1, 1]], dtype=torch.bool), - query_lens=[1, 2], - seq_lens=[2, 2], - context_lens=torch.tensor([1, 2]), - input_positions=torch.tensor([0, 1, 0, 1]), - query_start_loc=torch.tensor([0, 1, 3]), - block_table=torch.tensor([[0, 1], [2, 3]]), - max_query_len=2, - max_seq_lens=2, - sin=None, - cos=None, - chunked_context=chunked_context) - - self.assertIsNotNone(metadata.chunked_context) - self.assertIs(metadata.chunked_context.cu_seq_lens, cu_seq_lens) - self.assertIs(metadata.chunked_context.starts, starts) - self.assertEqual(metadata.chunked_context.seq_tot, seq_tot) - self.assertEqual(metadata.chunked_context.max_seq_lens, max_seq_lens) - self.assertIs(metadata.chunked_context.workspace, workspace) - self.assertIs(metadata.chunked_context.chunk_seq_lens, chunk_seq_lens) - - -class TestAscendSFATorchairDecodeMetadata(TestBase): - - def test_ascend_sfa_decode_metadata_default(self): - input_positions = torch.tensor([[1, 2, 3, 4], [1, 2, 3, 4]]) - block_table = torch.tensor([[0, 3, 2, 1], [0, 2, 1, 3]]) - seq_lens = 
torch.tensor([[2], [3]]) - max_seq_lens = 4 - seq_lens_list = [2, 3] - attn_mask = None - - metadata = AscendSFATorchairDecodeMetadata(input_positions, - block_table, seq_lens, - max_seq_lens, seq_lens_list, - None, None, attn_mask) - - self.assertIs(metadata.input_positions, input_positions) - self.assertIs(metadata.block_table, block_table) - self.assertIs(metadata.seq_lens, seq_lens) - self.assertEqual(metadata.max_seq_lens, max_seq_lens) - self.assertEqual(metadata.seq_lens_list, seq_lens_list) - self.assertIsNone(attn_mask) - - -class TestAscendSFATorchairMetadata(TestBase): - - def test_ascend_sfa_metadata_default(self): - num_actual_tokens = 100 - slot_mapping = torch.randn(100, 4, 1024) - query_start_loc = torch.tensor([1, 2, 3, 4]) - seq_lens = [30, 50] - block_tables = torch.randint(0, 100, (100, 4)) - - num_decodes = 4 - num_decode_tokens = 8 - num_prefills = 8 - - num_input_tokens = 2 - - query_lens = None - head_dim = None - attn_mask = None - attn_state = AscendAttentionState.ChunkedPrefill - - decode = None - prefill = None - - metadata = AscendSFATorchairMetadata( - num_actual_tokens, slot_mapping, query_start_loc, seq_lens, - block_tables, num_decodes, num_decode_tokens, num_prefills, - num_input_tokens, query_lens, head_dim, attn_mask, attn_state, - decode, prefill) - - self.assertEqual(metadata.num_actual_tokens, num_actual_tokens) - self.assertIs(metadata.slot_mapping, slot_mapping) - self.assertIs(metadata.query_start_loc, query_start_loc) - self.assertEqual(metadata.seq_lens, seq_lens) - self.assertIs(metadata.block_tables, block_tables) - self.assertEqual(metadata.num_decodes, num_decodes) - self.assertEqual(metadata.num_decode_tokens, num_decode_tokens) - self.assertEqual(metadata.num_prefills, num_prefills) - self.assertEqual(metadata.num_input_tokens, num_input_tokens) - self.assertEqual(metadata.query_lens, query_lens) - self.assertEqual(metadata.head_dim, head_dim) - self.assertEqual(metadata.attn_mask, attn_mask) - self.assertEqual(metadata.attn_state, attn_state) - self.assertEqual(metadata.decode, decode) - self.assertEqual(metadata.prefill, prefill) - - -class TestAscendSFATorchairMetadataBuilder(TestBase): - - def test_ascend_sfa_metadata_builder_default(self): - mock_model_config = MagicMock() - mock_model_config.max_model_len = 1024 - mock_model_config.get_head_size.return_value = 64 - mock_model_config.dtype = torch.float16 - - mock_vllm_config = MagicMock() - mock_vllm_config.model_config = mock_model_config - mock_vllm_config.cache_config = MagicMock(block_size=16) - mock_vllm_config.scheduler_config = MagicMock( - max_num_seqs=4, enable_chunked_prefill=False) - mock_vllm_config.speculative_config = None - - mock_device = torch.device('cpu') - ascend_config = MagicMock() - ascend_config.torchair_graph_config = MagicMock() - ascend_config.torchair_graph_config.enabled = True - with patch("vllm_ascend.torchair.torchair_sfa.get_ascend_config", - return_value=ascend_config): - builder = AscendSFATorchairMetadataBuilder(None, None, - mock_vllm_config, - mock_device) - - self.assertEqual(builder.block_size, - mock_vllm_config.cache_config.block_size) - self.assertEqual( - builder.chunked_prefill_enabled, - mock_vllm_config.scheduler_config.enable_chunked_prefill) - self.assertEqual(builder.torchair_graph_enabled, True) - self.assertEqual(builder.max_blocks, (mock_vllm_config.model_config.max_model_len + - mock_vllm_config.cache_config.block_size - 1) \ - // mock_vllm_config.cache_config.block_size) - - 
@patch("vllm_ascend.torchair.torchair_sfa.get_ascend_config") - def test_reorder_batch_with_torchair_graph(self, ascend_config): - mock_model_config = MagicMock() - mock_model_config.max_model_len = 1024 - mock_model_config.get_head_size.return_value = 64 - mock_model_config.dtype = torch.float16 - - mock_vllm_config = MagicMock() - mock_vllm_config.model_config = mock_model_config - mock_vllm_config.cache_config = MagicMock(block_size=16) - mock_vllm_config.scheduler_config = MagicMock( - max_num_seqs=4, enable_chunked_prefill=False) - mock_vllm_config.speculative_config = None - - mock_device = torch.device('cpu') - ascend_config.torchair_graph_config = MagicMock() - ascend_config.torchair_graph_config.enabled = True - - builder = AscendSFATorchairMetadataBuilder(None, None, - mock_vllm_config, - mock_device) - - input_batch = MagicMock() - input_batch.req_ids = [0, 1, 2, 3] - - scheduler_output = MagicMock() - scheduler_output.num_scheduled_tokens = {0: 2, 1: 1, 2: 3, 3: 1} - scheduler_output.scheduled_spec_decode_tokens = { - 0: [1], - 1: [], - 2: [1, 1], - 3: [] - } - - input_batch.swap_states = MagicMock() - - modified = builder.reorder_batch(input_batch, scheduler_output) - - self.assertFalse(modified) - input_batch.swap_states.assert_not_called() - - @patch("vllm_ascend.torchair.torchair_sfa.get_ascend_config") - def test_get_graph_runner_block_tables_normal(self, mock_ascend_config): - ascend_config = MagicMock() - mock_ascend_config.return_value = ascend_config - ascend_config.torchair_graph_config.enabled = False - mock_model_config = MagicMock() - mock_model_config.max_model_len = 1024 - mock_model_config.get_head_size.return_value = 64 - mock_model_config.dtype = torch.float16 - - mock_vllm_config = MagicMock() - mock_vllm_config.model_config = mock_model_config - mock_vllm_config.cache_config = MagicMock(block_size=16) - mock_vllm_config.scheduler_config = MagicMock( - max_num_seqs=4, enable_chunked_prefill=False) - mock_vllm_config.speculative_config = None - mock_device = torch.device('cpu') - - builder = AscendSFATorchairMetadataBuilder(None, None, - mock_vllm_config, - mock_device) - block_tables = torch.randint(0, 100, (3, 10), dtype=torch.int32) - - result = builder._get_graph_runner_block_tables(3, block_tables) - self.assertEqual(result.shape[0], 3) - self.assertEqual(result.shape[1], 64) - self.assertTrue(torch.equal(result[:, :10], block_tables)) - - @patch("vllm_ascend.torchair.torchair_sfa.get_ascend_config") - def test_ge_graph_runner_block_tables_truncated(self, mock_ascend_config): - ascend_config = MagicMock() - mock_ascend_config.return_value = ascend_config - ascend_config.torchair_graph_config.enabled = False - mock_model_config = MagicMock() - mock_model_config.max_model_len = 1024 - mock_model_config.get_head_size.return_value = 64 - mock_model_config.dtype = torch.float16 - - mock_vllm_config = MagicMock() - mock_vllm_config.model_config = mock_model_config - mock_vllm_config.cache_config = MagicMock(block_size=16) - mock_vllm_config.scheduler_config = MagicMock( - max_num_seqs=4, enable_chunked_prefill=False) - mock_vllm_config.speculative_config = None - - mock_device = torch.device('cpu') - - builder = AscendSFATorchairMetadataBuilder(None, None, - mock_vllm_config, - mock_device) - - builder.max_blocks = 4 - block_tables = torch.randint(0, 100, (3, 10), dtype=torch.int32) - - result = builder._get_graph_runner_block_tables(3, block_tables) - self.assertEqual(result.shape[0], 3) - self.assertEqual(result.shape[1], 4) - 
self.assertTrue(torch.equal(result, block_tables[:, :4])) - - @patch("vllm_ascend.torchair.torchair_sfa.get_ascend_config") - def test_get_graph_runner_block_tables_from_numpy(self, - mock_ascend_config): - ascend_config = MagicMock() - mock_ascend_config.return_value = ascend_config - ascend_config.torchair_graph_config.enabled = False - mock_model_config = MagicMock() - mock_model_config.max_model_len = 1024 - mock_model_config.get_head_size.return_value = 64 - mock_model_config.dtype = torch.float16 - - mock_vllm_config = MagicMock() - mock_vllm_config.model_config = mock_model_config - mock_vllm_config.cache_config = MagicMock(block_size=16) - mock_vllm_config.scheduler_config = MagicMock( - max_num_seqs=4, enable_chunked_prefill=False) - mock_vllm_config.speculative_config = None - - mock_device = torch.device('cpu') - builder = AscendSFATorchairMetadataBuilder(None, None, - mock_vllm_config, - mock_device) - - block_tables = torch.randint(0, 100, (3, 10), dtype=torch.int32) - - result = builder._get_graph_runner_block_tables(3, block_tables) - - self.assertEqual(result.shape[0], 3) - self.assertEqual(result.shape[1], 64) - self.assertTrue(torch.equal(result[:, :10], block_tables)) diff --git a/tests/ut/torchair/test_torchair_worker.py b/tests/ut/torchair/test_torchair_worker.py deleted file mode 100644 index 0397aee1..00000000 --- a/tests/ut/torchair/test_torchair_worker.py +++ /dev/null @@ -1,111 +0,0 @@ -from unittest.mock import MagicMock, patch - -import torch -from vllm.config import CacheConfig, ModelConfig, ParallelConfig, VllmConfig - -from tests.ut.base import TestBase - -init_cache_hf_modules_path = "vllm.utils.import_utils.init_cached_hf_modules" - - -class TestNPUTorchairWorker(TestBase): - - def setUp(self): - self.cache_config_mock = MagicMock(spec=CacheConfig) - self.cache_config_mock.cache_type = "auto" - - self.model_config_mock = MagicMock(spec=ModelConfig) - self.model_config_mock.dtype = torch.float16 - self.model_config_mock.trust_remote_code = False - - self.hf_config_mock = MagicMock() - self.hf_config_mock.model_type = "test_model" - if hasattr(self.hf_config_mock, 'index_topk'): - delattr(self.hf_config_mock, 'index_topk') - - self.model_config_mock.hf_config = self.hf_config_mock - - self.parallel_config_mock = MagicMock(spec=ParallelConfig) - - self.vllm_config_mock = MagicMock(spec=VllmConfig) - self.vllm_config_mock.cache_config = self.cache_config_mock - self.vllm_config_mock.model_config = self.model_config_mock - self.vllm_config_mock.parallel_config = self.parallel_config_mock - self.vllm_config_mock.additional_config = None - self.vllm_config_mock.load_config = None - self.vllm_config_mock.scheduler_config = None - self.vllm_config_mock.device_config = None - self.vllm_config_mock.compilation_config = None - - self.local_rank = 0 - self.rank = 0 - self.distributed_init_method = "tcp://localhost:12345" - self.is_driver_worker = False - - @patch( - "vllm_ascend.worker.worker_v1.NPUWorker._init_worker_distributed_environment" - ) - @patch("vllm_ascend.worker.worker_v1.NPUPlatform") - def test_init_device(self, mock_platform, mock_init_dist_env): - from vllm_ascend.worker.worker_v1 import NPUWorker - - mock_platform.mem_get_info.return_value = (1000, 2000) - - with patch.object(NPUWorker, "__init__", lambda x, **kwargs: None): - worker = NPUWorker() - worker.local_rank = 1 - worker.model_config = MagicMock() - worker.model_config.seed = 42 - worker.vllm_config = MagicMock() - worker.parallel_config = MagicMock() - worker.parallel_config.local_world_size 
= 0 - worker.parallel_config.data_parallel_size = 1 - - result = worker._init_device() - - mock_platform.set_device.assert_called_once() - call_args = mock_platform.set_device.call_args[0][0] - self.assertEqual(str(call_args), "npu:1") - - mock_platform.empty_cache.assert_called_once() - mock_platform.seed_everything.assert_called_once_with(42) - mock_platform.mem_get_info.assert_called_once() - mock_init_dist_env.assert_called_once() - - self.assertEqual(str(result), "npu:1") - self.assertEqual(worker.init_npu_memory, 1000) - - @patch( - "vllm_ascend.worker.worker_v1.NPUWorker._init_worker_distributed_environment" - ) - @patch("vllm_ascend.worker.worker_v1.NPUPlatform") - def test_init_device_torchair_worker(self, mock_platform, - mock_init_dist_env): - from vllm_ascend.torchair.torchair_worker import NPUTorchairWorker - - mock_platform.mem_get_info.return_value = (1000, 2000) - - with patch.object(NPUTorchairWorker, "__init__", - lambda x, **kwargs: None): - worker = NPUTorchairWorker() - worker.local_rank = 1 - worker.model_config = MagicMock() - worker.model_config.seed = 42 - worker.vllm_config = MagicMock() - worker.parallel_config = MagicMock() - worker.parallel_config.local_world_size = 0 - worker.parallel_config.data_parallel_size = 1 - - result = worker._init_device() - - mock_platform.set_device.assert_called_once() - call_args = mock_platform.set_device.call_args[0][0] - self.assertEqual(str(call_args), "npu:1") - - mock_platform.empty_cache.assert_called_once() - mock_platform.seed_everything.assert_called_once_with(42) - mock_platform.mem_get_info.assert_called_once() - mock_init_dist_env.assert_called_once() - - self.assertEqual(str(result), "npu:1") - self.assertEqual(worker.init_npu_memory, 1000) diff --git a/tests/ut/torchair/test_utils.py b/tests/ut/torchair/test_utils.py deleted file mode 100644 index 02528519..00000000 --- a/tests/ut/torchair/test_utils.py +++ /dev/null @@ -1,164 +0,0 @@ -import os -from concurrent.futures import ThreadPoolExecutor -from unittest import mock -from unittest.mock import MagicMock, patch - -import torch - -from tests.ut.base import TestBase -from vllm_ascend.torchair import utils - - -class TestTorchairUtils(TestBase): - - def test_get_torchair_current_work_dir(self): - cache_dir = utils.TORCHAIR_CACHE_DIR - work_dir = utils._get_torchair_current_work_dir() - self.assertEqual(cache_dir, work_dir) - work_dir = utils._get_torchair_current_work_dir("test") - self.assertEqual(os.path.join(cache_dir, "test"), work_dir) - - def test_torchair_cache_dir(self): - utils.write_kv_cache_bytes_to_file(0, 100) - self.assertTrue(utils.check_torchair_cache_exist(), - "Create torchair cache dir failed") - self.assertTrue(utils.check_kv_cache_bytes_cache_exist(), - "Create kv cache bytes cache dir failed") - kv_cache_bytes = utils.read_kv_cache_bytes_from_file(0) - self.assertEqual(100, kv_cache_bytes) - utils.delete_torchair_cache_file() - self.assertFalse(utils.check_torchair_cache_exist(), - "Delete torchair cache dir failed") - self.assertFalse(utils.check_kv_cache_bytes_cache_exist(), - "Delete kv cache bytes cache dir failed") - - def test_torchair_cache_dir_multiple_ranks(self): - ranks = [0, 1, 2, 3] - values = [100, 200, 300, 400] - - with ThreadPoolExecutor() as executor: - executor.map(utils.write_kv_cache_bytes_to_file, ranks, values) - for rank, expected in zip(ranks, values): - self.assertEqual(expected, - utils.read_kv_cache_bytes_from_file(rank)) - utils.delete_torchair_cache_file() - - self.assertFalse(utils.check_torchair_cache_exist(), - 
"Delete torchair cache dir failed") - self.assertFalse(utils.check_kv_cache_bytes_cache_exist(), - "Delete kv cache bytes cache dir failed") - - def test_delete_torchair_cache_file_multiple_times(self): - utils.write_kv_cache_bytes_to_file(0, 100) - utils.delete_torchair_cache_file() - for i in range(5): - try: - utils.delete_torchair_cache_file() - except FileNotFoundError: - self.fail( - f"Unexpected FileNotFoundError on delete call #{i+2}") - - @patch('vllm.ModelRegistry') - def test_register_torchair_model(self, mock_model_registry): - mock_registry = MagicMock() - mock_model_registry.return_value = mock_registry - utils.register_torchair_model() - - self.assertEqual(mock_model_registry.register_model.call_count, 7) - call_args_list = mock_model_registry.register_model.call_args_list - - expected_registrations = [ - ("DeepSeekMTPModel", - "vllm_ascend.torchair.models.torchair_deepseek_mtp:TorchairDeepSeekMTP" - ), - ("DeepseekV2ForCausalLM", - "vllm_ascend.torchair.models.torchair_deepseek_v2:TorchairDeepseekV2ForCausalLM" - ), - ("DeepseekV3ForCausalLM", - "vllm_ascend.torchair.models.torchair_deepseek_v3:TorchairDeepseekV3ForCausalLM" - ), - ("DeepseekV32ForCausalLM", - "vllm_ascend.torchair.models.torchair_deepseek_v3:TorchairDeepseekV3ForCausalLM" - ), - ("Qwen2ForCausalLM", - "vllm_ascend.torchair.models.qwen2:CustomQwen2ForCausalLM"), - ("Qwen3MoeForCausalLM", - "vllm_ascend.torchair.models.qwen3_moe:CustomQwen3MoeForCausalLM" - ), - ("PanguProMoEForCausalLM", - "vllm_ascend.torchair.models.torchair_pangu_moe:PanguProMoEForCausalLM" - ) - ] - - for i, (expected_name, - expected_path) in enumerate(expected_registrations): - args, kwargs = call_args_list[i] - self.assertEqual(args[0], expected_name) - self.assertEqual(args[1], expected_path) - - @mock.patch('vllm_ascend.torchair.utils.is_enable_nz') - @mock.patch('torch_npu.get_npu_format') - @mock.patch('torch_npu.npu_format_cast') - @mock.patch('vllm.model_executor.layers.fused_moe.layer.FusedMoE', - new=mock.MagicMock) - def test_converting_weight_acl_format_to_nz(self, mock_npu_cast, - mock_get_format, mock_is_nz): - ACL_FORMAT_FRACTAL_NZ = 29 - mock_get_format.return_value = 1 - mock_npu_cast.return_value = 1 - mock_is_nz.return_value = 1 - - fused_moe = mock.MagicMock() - fused_moe.w13_weight = mock.MagicMock() - fused_moe.w2_weight = mock.MagicMock() - fused_moe.w13_weight.data = torch.randn(128, 256) - fused_moe.w2_weight.data = torch.randn(256, 128) - model = mock.MagicMock() - model.modules.return_value = [fused_moe] - - utils.converting_weight_acl_format(model, ACL_FORMAT_FRACTAL_NZ) - self.assertEqual(fused_moe.w13_weight.data, 1) - - @mock.patch('torch_npu.get_npu_format') - @mock.patch('torch_npu.npu_format_cast') - @mock.patch('vllm.model_executor.layers.fused_moe.layer.FusedMoE', - new=mock.MagicMock) - def test_converting_weight_acl_format_format_true(self, mock_npu_cast, - mock_get_format): - ACL_FORMAT_FRACTAL_NZ = 29 - mock_get_format.return_value = ACL_FORMAT_FRACTAL_NZ - mock_npu_cast.return_value = 1 - - fused_moe = mock.MagicMock() - fused_moe.w13_weight = mock.MagicMock() - fused_moe.w2_weight = mock.MagicMock() - fused_moe.w13_weight.data = torch.randn(128, 256) - fused_moe.w2_weight.data = torch.randn(256, 128) - model = mock.MagicMock() - model.modules.return_value = [fused_moe] - - utils.converting_weight_acl_format(model, ACL_FORMAT_FRACTAL_NZ) - mock_npu_cast.assert_not_called() - - @mock.patch('vllm_ascend.torchair.utils.is_enable_nz') - @mock.patch('torch_npu.get_npu_format') - 
@mock.patch('torch_npu.npu_format_cast') - @mock.patch('vllm.model_executor.layers.fused_moe.layer.FusedMoE', - new=mock.MagicMock) - def test_converting_weight_acl_format_no_nz(self, mock_npu_cast, - mock_get_format, mock_is_nz): - ACL_FORMAT_FRACTAL_NZ = 29 - mock_get_format.return_value = 1 - mock_npu_cast.return_value = 1 - mock_is_nz.return_value = 0 - - fused_moe = mock.MagicMock() - fused_moe.w13_weight = mock.MagicMock() - fused_moe.w2_weight = mock.MagicMock() - fused_moe.w13_weight.data = torch.randn(128, 256) - fused_moe.w2_weight.data = torch.randn(256, 128) - model = mock.MagicMock() - model.modules.return_value = [fused_moe] - - utils.converting_weight_acl_format(model, ACL_FORMAT_FRACTAL_NZ) - mock_npu_cast.assert_not_called() diff --git a/vllm_ascend/ascend_config.py b/vllm_ascend/ascend_config.py index 3c96449e..54ef914a 100644 --- a/vllm_ascend/ascend_config.py +++ b/vllm_ascend/ascend_config.py @@ -18,15 +18,6 @@ from uuid import uuid4 from vllm.logger import logger -TORCHAIR_MODEL_LIST = ["deepseek", "pangu", "kimi_k2", "qwen"] - - -def _check_torchair_supported(model_type: str): - for supported_model in TORCHAIR_MODEL_LIST: - if supported_model in model_type.lower(): - return True - return False - def check_kv_extra_config(vllm_config): @@ -66,11 +57,6 @@ class AscendConfig: def __init__(self, vllm_config): additional_config = vllm_config.additional_config if vllm_config.additional_config is not None else {} - torchair_graph_config = additional_config.get("torchair_graph_config", - {}) - - self.torchair_graph_config = TorchairGraphConfig( - torchair_graph_config, vllm_config, additional_config) xlite_graph_config = additional_config.get("xlite_graph_config", {}) self.xlite_graph_config = XliteGraphConfig(xlite_graph_config, @@ -107,8 +93,8 @@ class AscendConfig: self.chunked_prefill_for_mla = additional_config.get( "chunked_prefill_for_mla", False) self.enable_shared_expert_dp = additional_config.get( - "enable_shared_expert_dp", False - ) and not self.torchair_graph_config.enabled and vllm_config.parallel_config.enable_expert_parallel + "enable_shared_expert_dp", + False) and vllm_config.parallel_config.enable_expert_parallel if self.enable_shared_expert_dp: from vllm_ascend.utils import enable_sp assert enable_sp(vllm_config=vllm_config, @@ -215,86 +201,6 @@ class AscendCompilationConfig: # Add more compilation related configs here as needed -class TorchairGraphConfig: - """ - Configuration Object for torchair_graph_config from additional_config - """ - - def __init__(self, torchair_graph_config, vllm_config, additional_config): - self.enabled = torchair_graph_config.get("enabled", False) - self.mode = torchair_graph_config.get("mode", '') - self.use_cached_graph = torchair_graph_config.get( - "use_cached_graph", False) - self.use_cached_kv_cache_bytes = torchair_graph_config.get( - "use_cached_kv_cache_bytes", False) - self.graph_batch_sizes = torchair_graph_config.get( - "graph_batch_sizes", []) - self.graph_batch_sizes_init = torchair_graph_config.get( - "graph_batch_sizes_init", False) - self.enable_multistream_mla = torchair_graph_config.get( - "enable_multistream_mla", False) - self.enable_view_optimize = torchair_graph_config.get( - "enable_view_optimize", True) - self.enable_frozen_parameter = torchair_graph_config.get( - "enable_frozen_parameter", True) - self.enable_kv_nz = torchair_graph_config.get("enable_kv_nz", False) - self.enable_super_kernel = torchair_graph_config.get( - "enable_super_kernel", False) - - if not isinstance(self.graph_batch_sizes, 
list): - raise TypeError("graph_batch_sizes must be list[int]") - if self.graph_batch_sizes_init and len(self.graph_batch_sizes) > 0: - raise ValueError( - "graph_batch_sizes_init is only valid when graph_batch_sizes is empty" - ) - if not self.enabled: - if self.mode: - raise RuntimeError( - "mode is valid only when Torchair graph mode is enabled") - if self.use_cached_graph: - raise RuntimeError( - "use_cached_graph is valid only when Torchair graph mode is enabled" - ) - if self.use_cached_kv_cache_bytes: - raise RuntimeError( - "use_cached_kv_cache_bytes is valid only when Torchair graph mode is enabled" - ) - if self.graph_batch_sizes: - raise RuntimeError( - "graph_batch_sizes is valid only when Torchair graph mode is enabled" - ) - if self.graph_batch_sizes_init: - raise RuntimeError( - "graph_batch_sizes_init is valid only when Torchair graph mode is enabled" - ) - if self.enable_multistream_mla: - raise RuntimeError( - "enable_multistream_mla is valid only when Torchair graph mode is enabled" - ) - if self.enable_kv_nz: - raise RuntimeError( - "enable_kv_nz is valid only when Torchair graph mode is enabled" - ) - if self.enable_super_kernel: - raise RuntimeError( - "enable_super_kernel is valid only when Torchair graph mode is enabled" - ) - if self.enable_super_kernel: - if vllm_config.parallel_config.tensor_parallel_size != 1: - raise RuntimeError( - "enable_super_kernel is valid only when tensor_parallel_size is 1" - ) - if not additional_config.get("multistream_overlap_shared_expert", - False): - raise RuntimeError( - "enable_super_kernel is valid only when multistream_overlap_shared_expert is enabled" - ) - if self.use_cached_kv_cache_bytes and not self.use_cached_graph: - raise RuntimeError( - "use_cached_kv_cache_bytes is valid only when Torchair graph mode and use_cached_graph are enabled" - ) - - class XliteGraphConfig: """ Configuration Object for xlite_graph_config from additional_config @@ -382,39 +288,7 @@ def get_ascend_config(): def check_ascend_config(vllm_config, enforce_eager): ascend_config = get_ascend_config() - # for eager mode - if enforce_eager: - # torchair_graph cannot be enabled with eager mode. - if ascend_config.torchair_graph_config.enabled: - raise RuntimeError( - "Can't enable graph mode and eager mode at the same time. Please set `enforce_eager=False` if you attempt to enable NPU graph mode." - ) - # for graph mode - else: - # torchair_graph case - if ascend_config.torchair_graph_config.enabled: - # torchair_graph is supported for deepseek/pangu/qwen model only. - if vllm_config.model_config: - model_type = vllm_config.model_config.hf_config.model_type - if not _check_torchair_supported(model_type): - raise NotImplementedError( - "Torchair graph mode only works with following model types:" - f"{TORCHAIR_MODEL_LIST}.") - if ascend_config.enable_shared_expert_dp: - logger.warning( - "enable_shared_expert_dp is not supported for torchair graph mode currently, " - "it has been disabled automatically.") - # aclgraph case - else: - if ascend_config.ascend_compilation_config.enable_quantization_fusion: - logger.info( - "Quantization fusion enabled! op fusion on quantization are expected. " - ) - - if vllm_config.model_config: - model_type = vllm_config.model_config.hf_config.model_type - if "qwen" not in model_type: - logger.warning( - "ACL Graph is currently experimental. 
Please " - "raise an issue on https://github.com/vllm-project/vllm-ascend/issues" - " if you encourage any Error") + if ascend_config.ascend_compilation_config.enable_quantization_fusion: + logger.info( + "Quantization fusion enabled! op fusion on quantization are expected. " + ) diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index a6a6447c..348efc33 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -857,7 +857,6 @@ class AscendMLAImpl(MLAAttentionImpl): ascend_config = get_ascend_config() self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp self.enable_prefetch = ascend_config.weight_prefetch_config.enabled - self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz vllm_config = get_current_vllm_config() self.ring_mla_mask_size = 512 @@ -1248,7 +1247,7 @@ class AscendMLAImpl(MLAAttentionImpl): # npu_kv_rmsnorm_rope_cache needs [B, N, S, D] kv_no_split = kv_no_split.view( B, N, S, self.kv_lora_rank + self.qk_rope_head_dim) - cache_mode = "PA_NZ" if self.enable_kv_nz else "PA" + cache_mode = "PA" k_pe, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache( kv_no_split, self.kv_a_layernorm.weight, @@ -1276,7 +1275,7 @@ class AscendMLAImpl(MLAAttentionImpl): # npu_kv_rmsnorm_rope_cache needs [B, N, S, D] kv_no_split = kv_no_split.view( B, N, S, self.kv_lora_rank + self.qk_rope_head_dim) - cache_mode = "PA_NZ" if self.enable_kv_nz else "PA" + cache_mode = "PA" _, _, k_pe, k_nope = torch_npu.npu_kv_rmsnorm_rope_cache( kv_no_split, self.kv_a_layernorm.weight, @@ -1318,18 +1317,11 @@ class AscendMLAImpl(MLAAttentionImpl): # shape of knope/k_pe for npu graph mode should be: # [num_blocks, num_kv_heads, block_size, self.kv_lora_rank/self.qk_rope_head_dim] actual_seq_lengths = None - if self.enable_kv_nz: - k_nope = k_nope.view(-1, self.num_kv_heads, - self.kv_lora_rank // 16, block_size, 16) - k_pe = k_pe.view(-1, self.num_kv_heads, - self.qk_rope_head_dim // 16, block_size, 16) - input_layout = "BSND" - else: - k_nope = k_nope.view(-1, self.num_kv_heads, block_size, - self.kv_lora_rank) - k_pe = k_pe.view(-1, self.num_kv_heads, block_size, - self.qk_rope_head_dim) - input_layout = "BNSD" + k_nope = k_nope.view(-1, self.num_kv_heads, block_size, + self.kv_lora_rank) + k_pe = k_pe.view(-1, self.num_kv_heads, block_size, + self.qk_rope_head_dim) + input_layout = "BNSD" if attn_metadata.attn_state in [ AscendAttentionState.SpecDecoding, @@ -1346,14 +1338,9 @@ class AscendMLAImpl(MLAAttentionImpl): spec_attn_mask = attn_metadata.decode.attn_mask # type:ignore actual_seq_lengths = decode_meta.actual_seq_lengths_q else: - if self.enable_kv_nz: - q_nope = q_nope.view(num_tokens, 1, self.num_heads, - -1).contiguous() - q_pe = q_pe.view(num_tokens, 1, self.num_heads, -1) - else: - q_nope = q_nope.view(num_tokens, self.num_heads, 1, - -1).contiguous() - q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1) + q_nope = q_nope.view(num_tokens, self.num_heads, 1, + -1).contiguous() + q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1) sparse_mode = 0 spec_attn_mask = None diff --git a/vllm_ascend/attention/sfa_v1.py b/vllm_ascend/attention/sfa_v1.py index 6ee35f8c..550f08ed 100644 --- a/vllm_ascend/attention/sfa_v1.py +++ b/vllm_ascend/attention/sfa_v1.py @@ -345,7 +345,6 @@ class AscendSFAImpl(MLAAttentionImpl): ascend_config = get_ascend_config() self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp self.enable_prefetch = ascend_config.weight_prefetch_config.enabled - self.enable_kv_nz = 
ascend_config.torchair_graph_config.enable_kv_nz self.enable_mlapo = envs.VLLM_ASCEND_ENABLE_MLAPO assert self.indexer is not None, "Indexer is required for DSA." @@ -534,7 +533,7 @@ class AscendSFAImpl(MLAAttentionImpl): # npu_kv_rmsnorm_rope_cache needs [B, N, S, D] kv_no_split = kv_no_split.view( B, N, S, self.kv_lora_rank + self.qk_rope_head_dim) - cache_mode = "PA_NZ" if self.enable_kv_nz else "PA" + cache_mode = "PA" if self.enable_sfa_cp: assert slots_cp is not None diff --git a/vllm_ascend/distributed/mooncake_connector.py b/vllm_ascend/distributed/mooncake_connector.py index 8303c2dc..58968c1f 100644 --- a/vllm_ascend/distributed/mooncake_connector.py +++ b/vllm_ascend/distributed/mooncake_connector.py @@ -453,7 +453,6 @@ class KVCacheRecvingThread(threading.Thread): def _cat_kv_cache(self, block_ids: list[list[int]]): # Get necessary parameters k_cache = list(self.kv_caches.values())[0][0] - kv_shape = k_cache.shape dtype = k_cache.dtype device = k_cache.device head_dim = self.model_config.hf_config.head_dim @@ -494,13 +493,6 @@ class KVCacheRecvingThread(threading.Thread): # Process each layer in the KV cache for _, (k_cache_layer, v_cache_layer) in self.kv_caches.items(): - if len( - k_cache_layer.shape - ) == 3: # kv shape in torchair model is [num_block, block_size, num_kv_head*head_dim] - k_cache_layer = k_cache_layer.view(kv_shape[0], kv_shape[1], - num_kv_head, head_dim) - v_cache_layer = v_cache_layer.view(kv_shape[0], kv_shape[1], - num_kv_head, head_dim) # Load cache data into buffers torch_npu.atb.npu_paged_cache_load( k_cache_layer, diff --git a/vllm_ascend/ops/fused_moe/moe_comm_method.py b/vllm_ascend/ops/fused_moe/moe_comm_method.py index 138dcddf..32604a1f 100644 --- a/vllm_ascend/ops/fused_moe/moe_comm_method.py +++ b/vllm_ascend/ops/fused_moe/moe_comm_method.py @@ -99,8 +99,6 @@ class MoECommMethod(ABC): w2_scale: Optional[list[torch.Tensor]] = None, w1_scale_bias: torch.Tensor = None, w2_scale_bias: torch.Tensor = None, - # For TorchAir graph - is_torchair: bool = False, # For Cube/Vector parallel shared_experts: Optional[Any] = None, quantized_x_for_share: Optional[Any] = None, @@ -283,8 +281,6 @@ class FusedAlltoAllCommImpl(MoECommMethod): w2_scale: Optional[torch.Tensor] = None, w1_scale_bias: torch.Tensor = None, w2_scale_bias: torch.Tensor = None, - # For TorchAir graph - is_torchair: bool = False, # For Cube/Vector parallel shared_experts: Optional[Any] = None, quantized_x_for_share: Optional[Any] = None, diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py index 532883e5..c9abc8ff 100644 --- a/vllm_ascend/platform.py +++ b/vllm_ascend/platform.py @@ -26,10 +26,7 @@ from vllm.platforms import Platform, PlatformEnum # todo: please remove it when solve cuda hard code in vllm os.environ["VLLM_DISABLE_SHARED_EXPERTS_STREAM"] = "1" -from vllm_ascend.ascend_config import (check_ascend_config, get_ascend_config, - init_ascend_config) -from vllm_ascend.torchair.utils import (check_torchair_cache_exist, - delete_torchair_cache_file) +from vllm_ascend.ascend_config import check_ascend_config, init_ascend_config from vllm_ascend.utils import refresh_block_size # isort: off @@ -204,25 +201,6 @@ class NPUPlatform(Platform): compilation_config.mode) compilation_config.cudagraph_mode = CUDAGraphMode.NONE - # set CUDAGraphMode to None when torchair is enabled, no mather what compilation_config.level is. - if ascend_config.torchair_graph_config.enabled: - logger.info( - "Torchair compilation enabled on NPU. 
Setting CUDAGraphMode to NONE" - ) - compilation_config.cudagraph_mode = CUDAGraphMode.NONE - # Note: We delete the torchair cache folder here to prevent runtime issues caused by dimension - # mismatches or configuration inconsistencies when users reuse cached computation graphs. Though - # this will increase graph compilation duration, it significantly enhances robustness and decreases - # graph launching time during inference. - if check_torchair_cache_exist( - ) and not ascend_config.torchair_graph_config.use_cached_kv_cache_bytes: - logger.warning( - "Torchair cache folder is deleted here to prevent runtime issues caused by dimension " - "mismatches or configuration inconsistencies when users reuse cached computation graphs. " - "In order to decrease torchair graph compilation time, users can enable both use_cached_graph " - "and use_cached_kv_cache_bytes in torchair_graph_config.") - delete_torchair_cache_file() - # set cudaprah sizes before extending `compilation_config.splitting_ops` vllm_config._set_cudagraph_sizes() # There are cases where default cudagraph_capture_sizes are not friendly @@ -303,9 +281,7 @@ class NPUPlatform(Platform): if parallel_config and parallel_config.worker_cls == "auto": # TODO: this is a tricky way to disable `use_sequence_parallel_moe` in vllm. parallel_config.all2all_backend = "flashinfer_all2allv" - if ascend_config.torchair_graph_config.enabled: - parallel_config.worker_cls = "vllm_ascend.torchair.torchair_worker.NPUTorchairWorker" - elif ascend_config.xlite_graph_config.enabled: + if ascend_config.xlite_graph_config.enabled: logger.info( "Euler Xlite enabled. See: https://gitee.com/openeuler/GVirt/tree/master/xlite" ) @@ -390,29 +366,14 @@ class NPUPlatform(Platform): use_sparse=False, attn_type: str | None = None, ): - ascend_config = get_ascend_config() - - if use_mla and ascend_config.enable_shared_expert_dp: - if use_mla and use_sparse: - return "vllm_ascend.torchair.torchair_sfa.AscendSFATorchairBackend" - - use_torchair = ascend_config.torchair_graph_config.enabled - # choose attention backend based on use_mla and use_torchair + # choose attention backend based on use_mla backend_map = { - (True, False, True): - "vllm_ascend.torchair.torchair_mla.AscendMLATorchairBackend", - (True, False, False): - "vllm_ascend.attention.mla_v1.AscendMLABackend", - (False, False, True): - "vllm_ascend.torchair.torchair_attention.AscendAttentionTorchairBackend", - (False, False, False): + (True, False): "vllm_ascend.attention.mla_v1.AscendMLABackend", + (False, False): "vllm_ascend.attention.attention_v1.AscendAttentionBackend", - (True, True, False): - "vllm_ascend.attention.sfa_v1.AscendSFABackend", - (True, True, True): - "vllm_ascend.torchair.torchair_sfa.AscendSFATorchairBackend", + (True, True): "vllm_ascend.attention.sfa_v1.AscendSFABackend", } - return backend_map[(use_mla, use_sparse, use_torchair)] + return backend_map[(use_mla, use_sparse)] @classmethod def get_punica_wrapper(cls) -> str: diff --git a/vllm_ascend/quantization/w8a8_dynamic.py b/vllm_ascend/quantization/w8a8_dynamic.py index 2e86dd6e..6fd03857 100644 --- a/vllm_ascend/quantization/w8a8_dynamic.py +++ b/vllm_ascend/quantization/w8a8_dynamic.py @@ -111,10 +111,9 @@ class AscendW8A8DynamicFusedMoEMethod: vllm_config = get_current_vllm_config() ascend_config = get_ascend_config() - self.use_aclgraph = ( - vllm_config.compilation_config.mode == CompilationMode.VLLM_COMPILE - and not vllm_config.model_config.enforce_eager - and not ascend_config.torchair_graph_config.enabled) + 
self.use_aclgraph = (vllm_config.compilation_config.mode + == CompilationMode.VLLM_COMPILE + and not vllm_config.model_config.enforce_eager) self.dynamic_eplb = ascend_config.dynamic_eplb or ascend_config.expert_map_record_path self.in_dtype = vllm_config.model_config.dtype diff --git a/vllm_ascend/spec_decode/__init__.py b/vllm_ascend/spec_decode/__init__.py index 50f65de7..df5015f1 100644 --- a/vllm_ascend/spec_decode/__init__.py +++ b/vllm_ascend/spec_decode/__init__.py @@ -20,21 +20,14 @@ from vllm_ascend.spec_decode.eagle_proposer import EagleProposer from vllm_ascend.spec_decode.mtp_proposer import MtpProposer from vllm_ascend.spec_decode.ngram_proposer import NgramProposer from vllm_ascend.spec_decode.suffix_proposer import SuffixDecodingProposer -from vllm_ascend.torchair.torchair_mtp_proposer import TorchairMtpProposer -def get_spec_decode_method(method, - vllm_config, - device, - runner, - is_torchair_graph=False): +def get_spec_decode_method(method, vllm_config, device, runner): if method == "ngram": return NgramProposer(vllm_config, device, runner) elif method in ("eagle", "eagle3"): return EagleProposer(vllm_config, device, runner) elif method == "mtp": - if is_torchair_graph: - return TorchairMtpProposer(vllm_config, device, runner) return MtpProposer(vllm_config, device, runner) elif method == 'suffix': return SuffixDecodingProposer(vllm_config, device, runner) diff --git a/vllm_ascend/torchair/__init__.py b/vllm_ascend/torchair/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/vllm_ascend/torchair/models/__init__.py b/vllm_ascend/torchair/models/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/vllm_ascend/torchair/models/qwen2.py b/vllm_ascend/torchair/models/qwen2.py deleted file mode 100644 index bc1525d9..00000000 --- a/vllm_ascend/torchair/models/qwen2.py +++ /dev/null @@ -1,357 +0,0 @@ -# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. -# Copyright 2023 The vLLM team. -# -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# This file is a part of the vllm-ascend project. 
- -from collections.abc import Iterable -from typing import Any, List, Optional, Union - -import torch -import torch.nn.functional as F -import vllm -from torch import nn -from transformers import Qwen2Config -from vllm.attention.backends.abstract import AttentionMetadata, AttentionType -from vllm.compilation.decorators import support_torch_compile -from vllm.config import CacheConfig, VllmConfig -from vllm.distributed import (get_pp_group, tensor_model_parallel_all_gather, - tensor_model_parallel_reduce_scatter) -from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding -from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead -from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP -from vllm.model_executor.models.qwen2 import Qwen2Attention # noqa: F401 -from vllm.model_executor.models.qwen2 import Qwen2ForCausalLM # noqa: F401 -from vllm.model_executor.models.qwen2 import Qwen2MLP, Qwen2Model -from vllm.model_executor.models.utils import (AutoWeightsLoader, - PPMissingLayer, maybe_prefix) -from vllm.sequence import IntermediateTensors -from vllm.transformers_utils.config import set_default_rope_theta - -from vllm_ascend.ascend_config import get_ascend_config -from vllm_ascend.attention.attention_v1 import AscendAttentionState - - -def all_gather_and_maybe_unpad( - hidden_states: torch.Tensor, - pad_size: int, -) -> torch.Tensor: - hidden_states = tensor_model_parallel_all_gather(hidden_states, 0) - if pad_size > 0: - return hidden_states[:-pad_size, :] - return hidden_states - - -def maybe_pad_and_reduce_scatter( - hidden_states: torch.Tensor, - pad_size: int, -) -> torch.Tensor: - if pad_size > 0: - hidden_states = F.pad(hidden_states, (0, 0, 0, pad_size)) - hidden_states = tensor_model_parallel_reduce_scatter(hidden_states, 0) - return hidden_states - - -class CustomQwen2Attention(Qwen2Attention): - - def __init__( - self, - hidden_size: int, - num_heads: int, - num_kv_heads: int, - rope_parameters: Optional[dict[str, Any]] = None, - max_position: int = 4096 * 32, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - attn_type: str = AttentionType.DECODER, - dual_chunk_attention_config: Optional[dict[str, Any]] = None, - ) -> None: - super().__init__( - hidden_size=hidden_size, - num_heads=num_heads, - num_kv_heads=num_kv_heads, - max_position=max_position, - cache_config=cache_config, - quant_config=quant_config, - prefix=prefix, - attn_type=attn_type, - dual_chunk_attention_config=dual_chunk_attention_config, - rope_parameters=rope_parameters) - - ascend_config = get_ascend_config() - self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled - - def forward( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - kv_cache: Optional[torch.Tensor] = None, - attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor: - qkv, _ = self.qkv_proj(hidden_states) - q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - if self.torchair_graph_enabled and attn_metadata is not None and attn_metadata.attn_state == AscendAttentionState.DecodeOnly: - q, k = self.rotary_emb(positions, - q, - k, - is_prefill=False, - is_qwen_torchair=True) - forward_kwargs = {} - output_shape = q.shape - output = torch.empty(output_shape, 
dtype=q.dtype, device=q.device) - forward_kwargs['output'] = output - - attn_output = self.attn.impl.forward(self.attn, - q, - k, - v, - kv_cache=kv_cache, - attn_metadata=attn_metadata, - **forward_kwargs) - output, _ = self.o_proj(attn_output) - return output - else: - if type(self.rotary_emb) is RotaryEmbedding: - q, k = self.rotary_emb(positions, q, k, is_qwen_torchair=True) - else: - q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, k, v) - output, _ = self.o_proj(attn_output) - return output - - -class CustomQwen2DecoderLayer(nn.Module): - - def __init__( - self, - config: Qwen2Config, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ) -> None: - super().__init__() - self.hidden_size = config.hidden_size - - set_default_rope_theta(config, default_theta=1000000) - - dual_chunk_attention_config = getattr(config, - "dual_chunk_attention_config", - None) - - # By default, Qwen2 uses causal attention as it is a decoder-only model. - # You can override the HF config with `is_causal=False` to enable - # bidirectional attention, which is used in some embedding models - # (e.g. Alibaba-NLP/gte-Qwen2-7B-instruct) - if getattr(config, "is_causal", True): - attn_type = AttentionType.DECODER - else: - attn_type = AttentionType.ENCODER_ONLY - - self.self_attn = CustomQwen2Attention( - hidden_size=self.hidden_size, - num_heads=config.num_attention_heads, - max_position=config.max_position_embeddings, - num_kv_heads=config.num_key_value_heads, - rope_parameters=config.rope_parameters, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.self_attn", - attn_type=attn_type, - dual_chunk_attention_config=dual_chunk_attention_config, - ) - self.mlp = Qwen2MLP( - hidden_size=self.hidden_size, - intermediate_size=config.intermediate_size, - hidden_act=config.hidden_act, - quant_config=quant_config, - prefix=f"{prefix}.mlp", - ) - self.input_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_attention_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - - def forward( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - residual: Optional[torch.Tensor], - kv_cache: Optional[torch.Tensor] = None, - attn_metadata: Optional[AttentionMetadata] = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - # Self Attention - if residual is None: - residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) - else: - hidden_states, residual = self.input_layernorm( - hidden_states, residual) - - hidden_states = self.self_attn(positions=positions, - hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata) - - # Fully Connected - hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual) - hidden_states = self.mlp(hidden_states) - return hidden_states, residual - - -@support_torch_compile( - dynamic_arg_dims={ - "input_ids": 0, - # positions is of shape (3, seq_len) if mrope is enabled for qwen2-vl, - # otherwise (seq_len, ). 
- "positions": -1, - "intermediate_tensors": 0, - "inputs_embeds": 0, - }) -class CustomQwen2Model(Qwen2Model): - - def __init__( - self, - *, - vllm_config: VllmConfig, - prefix: str = "", - decoder_layer_type: type[nn.Module] = CustomQwen2DecoderLayer): - super().__init__(vllm_config=vllm_config, - prefix=prefix, - decoder_layer_type=decoder_layer_type) - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - kv_caches: Optional[List[torch.Tensor]] = None, - attn_metadata: Optional[AttentionMetadata] = None, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: - if get_pp_group().is_first_rank: - if inputs_embeds is not None: - hidden_states = inputs_embeds - else: - hidden_states = self.embed_input_ids(input_ids) - residual = None - else: - assert intermediate_tensors is not None - hidden_states = intermediate_tensors["hidden_states"] - residual = intermediate_tensors["residual"] - - for i in range(self.start_layer, self.end_layer): - layer = self.layers[i] - kv_cache = kv_caches[i - self.start_layer] \ - if kv_caches is not None else None - hidden_states, residual = layer(positions, - hidden_states, - residual, - kv_cache=kv_cache, - attn_metadata=attn_metadata) - - if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) - - hidden_states, _ = self.norm(hidden_states, residual) - return hidden_states - - -class CustomQwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): - # add `CustomQwen2Model` to init self.model - packed_modules_mapping = { - "qkv_proj": [ - "q_proj", - "k_proj", - "v_proj", - ], - "gate_up_proj": [ - "gate_proj", - "up_proj", - ], - } - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__() - config = vllm_config.model_config.hf_config - quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config - - self.config = config - self.lora_config = lora_config - - self.quant_config = quant_config - self.model = CustomQwen2Model(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) - - if get_pp_group().is_last_rank: - if config.tie_word_embeddings: - self.lm_head = self.model.embed_tokens - else: - self.lm_head = ParallelLMHead(config.vocab_size, - config.hidden_size, - quant_config=quant_config, - prefix=maybe_prefix( - prefix, "lm_head")) - else: - self.lm_head = PPMissingLayer() - - self.logits_processor = LogitsProcessor(config.vocab_size) - - self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) - - def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.model.embed_input_ids(input_ids) - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - kv_caches: Optional[List[torch.Tensor]] = None, - attn_metadata: Optional[AttentionMetadata] = None, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors, - inputs_embeds) - return hidden_states - - def compute_logits( - self, - hidden_states: torch.Tensor, - sampling_metadata=None, # type: ignore - ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) - return logits - - def load_weights(self, weights: Iterable[tuple[str, - 
torch.Tensor]]) -> set[str]: - loader = AutoWeightsLoader( - self, - skip_prefixes=(["lm_head."] - if self.config.tie_word_embeddings else None), - ) - return loader.load_weights(weights) - - -vllm.model_executor.models.qwen2.Qwen2ForCausalLM = CustomQwen2ForCausalLM diff --git a/vllm_ascend/torchair/models/qwen3_moe.py b/vllm_ascend/torchair/models/qwen3_moe.py deleted file mode 100644 index 8338946f..00000000 --- a/vllm_ascend/torchair/models/qwen3_moe.py +++ /dev/null @@ -1,527 +0,0 @@ -# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. -# Copyright 2024 The Qwen team. -# Copyright 2023 The vLLM team. -# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# Adapted from vllm/model_executor/models/qwen3_moe.py -# This file is a part of the vllm-ascend project. -from typing import Any, List, Optional, Union - -import torch -from torch import nn -from transformers import PretrainedConfig -from vllm.attention.backends.abstract import AttentionMetadata -from vllm.attention.layer import Attention -from vllm.compilation.decorators import support_torch_compile -from vllm.config import CacheConfig, CompilationMode, VllmConfig -from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size -from vllm.distributed.parallel_state import (get_dp_group, get_ep_group, - get_tp_group) -from vllm.forward_context import get_forward_context -from vllm.model_executor.layers.fused_moe.layer import FusedMoE -from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (QKVParallelLinear, - ReplicatedLinear, - RowParallelLinear) -from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.models.interfaces import (MixtureOfExperts, - SupportsLoRA, SupportsPP) -from vllm.model_executor.models.qwen3_moe import (Qwen3MoeAttention, - Qwen3MoeDecoderLayer, - Qwen3MoeForCausalLM, - Qwen3MoeMLP, Qwen3MoeModel, - Qwen3MoeSparseMoeBlock) -from vllm.model_executor.models.utils import ( - PPMissingLayer, extract_layer_index, - make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) -from vllm.sequence import IntermediateTensors - -from vllm_ascend.ascend_config import get_ascend_config -from vllm_ascend.attention.attention_v1 import AscendAttentionState -from vllm_ascend.torchair.ops.sequence_parallel import (MetadataForPadding, - init_metadata_for_sp) -from vllm_ascend.torchair.ops.torchair_fused_moe import TorchairAscendFusedMoE - - -class CustomSparseMoeBlock(Qwen3MoeSparseMoeBlock): - - def __init__( - self, - config: PretrainedConfig, - quant_config: Optional[QuantizationConfig] = None, - prefix: 
str = "", - ): - nn.Module.__init__(self) - self.tp_size = get_tensor_model_parallel_world_size() - if self.tp_size > config.num_experts: - raise ValueError( - f"Tensor parallel size {self.tp_size} is greater than " - f"the number of experts {config.num_experts}.") - - self.gate = ReplicatedLinear( - config.hidden_size, - config.num_experts, - bias=False, - quant_config=None, - prefix=f"{prefix}.gate", - ) - - self.experts = TorchairAscendFusedMoE( - num_experts=config.num_experts, - top_k=config.num_experts_per_tok, - hidden_size=config.hidden_size, - intermediate_size=config.moe_intermediate_size, - reduce_results=False, - renormalize=config.norm_topk_prob, - quant_config=quant_config, - prefix=f"{prefix}.experts", - ) - - self.top_k = config.num_experts_per_tok - - self.dp_size = get_dp_group().world_size - - self.tp_group = get_tp_group().device_group - self.tp_rank = get_tp_group().rank_in_group - self.ep_group = get_ep_group() - - self.params_dtype = torch.get_default_dtype() - - def forward( - self, - hidden_states, - attn_metadata=None, - _metadata_for_padding: Optional[MetadataForPadding] = None, - ): - if attn_metadata is None: - attn_metadata = get_forward_context().attn_metadata - # when profile runs, force experts to load balanced tokens - # to avoid high memory consumption on a single rank. - enable_force_load_balance = get_forward_context().in_profile_run - is_prefill = get_forward_context().with_prefill - - # router_logits: (num_tokens, n_experts) - router_logits, _ = self.gate(hidden_states) - - hidden_states = self.experts( - hidden_states=hidden_states, - router_logits=router_logits, - is_prefill=is_prefill, - top_k=self.top_k, - enable_force_load_balance=enable_force_load_balance, - shared_experts=None, - _metadata_for_padding=_metadata_for_padding, - ) - - return hidden_states - - -class CustomQwen3MoeAttention(Qwen3MoeAttention): - - def __init__( - self, - hidden_size: int, - num_heads: int, - num_kv_heads: int, - rope_parameters: dict[str, Any], - max_position_embeddings: int = 8192, - head_dim: Optional[int] = None, - rms_norm_eps: float = 1e-06, - qkv_bias: bool = False, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ) -> None: - nn.Module.__init__(self) - self.hidden_size = hidden_size - tp_size = get_tensor_model_parallel_world_size() - self.total_num_heads = num_heads - assert self.total_num_heads % tp_size == 0 - self.num_heads = self.total_num_heads // tp_size - self.total_num_kv_heads = num_kv_heads - if self.total_num_kv_heads >= tp_size: - # Number of KV heads is greater than TP size, so we partition - # the KV heads across multiple tensor parallel GPUs. - assert self.total_num_kv_heads % tp_size == 0 - else: - # Number of KV heads is less than TP size, so we replicate - # the KV heads across multiple tensor parallel GPUs. 
- assert tp_size % self.total_num_kv_heads == 0 - self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) - self.head_dim = head_dim or (hidden_size // self.total_num_heads) - self.q_size = self.num_heads * self.head_dim - self.kv_size = self.num_kv_heads * self.head_dim - self.scaling = self.head_dim**-0.5 - self.max_position_embeddings = max_position_embeddings - - self.qkv_proj = QKVParallelLinear(hidden_size, - self.head_dim, - self.total_num_heads, - self.total_num_kv_heads, - bias=qkv_bias, - quant_config=quant_config, - prefix=f"{prefix}.qkv_proj") - - self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim, - hidden_size, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.o_proj") - - self.rotary_emb = get_rope( - self.head_dim, - rotary_dim=self.head_dim, - max_position=max_position_embeddings, - rope_parameters=rope_parameters, - ) - self.attn = Attention(self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn") - - self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) - self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) - ascend_config = get_ascend_config() - self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled - - @staticmethod - def normalize_qkv(qkv: torch.Tensor, q_size: int, kv_size: int, - head_dim: int, q_norm, k_norm): - q, k, v = qkv.split([q_size, kv_size, kv_size], dim=-1) - - q_by_head = q.view(*q.shape[:-1], q.shape[-1] // head_dim, head_dim) - q_by_head = q_norm(q_by_head) - q = q_by_head.view(q.shape) - - k_by_head = k.view(*k.shape[:-1], k.shape[-1] // head_dim, head_dim) - k_by_head = k_norm(k_by_head) - k = k_by_head.view(k.shape) - - return q, k, v - - def forward( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - kv_cache: Optional[torch.Tensor] = None, - attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor: - qkv, _ = self.qkv_proj(hidden_states) - q, k, v = self.normalize_qkv(qkv, self.q_size, self.kv_size, - self.head_dim, self.q_norm, self.k_norm) - - if (self.torchair_graph_enabled and attn_metadata is not None and - attn_metadata.attn_state == AscendAttentionState.DecodeOnly): - q, k = self.rotary_emb(positions, - q, - k, - is_prefill=False, - is_qwen_torchair=True) - forward_kwargs = {} - output_shape = q.shape - output = torch.empty(output_shape, dtype=q.dtype, device=q.device) - forward_kwargs['output'] = output - - attn_output = self.attn.impl.forward(self.attn, - q, - k, - v, - kv_cache=kv_cache, - attn_metadata=attn_metadata, - **forward_kwargs) - output, _ = self.o_proj(attn_output) - return output - else: - q, k = self.rotary_emb(positions, q, k, is_qwen_torchair=True) - attn_output = self.attn(q, k, v) - output, _ = self.o_proj(attn_output) - return output - - -class CustomQwen3MoeDecoderLayer(Qwen3MoeDecoderLayer): - - def __init__( - self, - config: PretrainedConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - vllm_config: Optional[VllmConfig] = None, - prefix: str = "", - ) -> None: - - nn.Module.__init__(self) - self.hidden_size = config.hidden_size - max_position_embeddings = getattr(config, "max_position_embeddings", - 8192) - self.self_attn = CustomQwen3MoeAttention( - hidden_size=self.hidden_size, - num_heads=config.num_attention_heads, - num_kv_heads=config.num_key_value_heads, - rope_parameters=config.rope_parameters, - max_position_embeddings=max_position_embeddings, - 
rms_norm_eps=config.rms_norm_eps, - qkv_bias=getattr(config, 'attention_bias', False), - head_dim=getattr(config, 'head_dim', None), - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.self_attn", - ) - - # `mlp_only_layers` in the config. - layer_idx = extract_layer_index(prefix) - mlp_only_layers = ([] if not hasattr(config, "mlp_only_layers") else - config.mlp_only_layers) - self.use_aclgraph = (vllm_config is not None - and vllm_config.compilation_config.mode - == CompilationMode.VLLM_COMPILE - and not vllm_config.model_config.enforce_eager) - if (layer_idx not in mlp_only_layers) and ( - config.num_experts > 0 and - (layer_idx + 1) % config.decoder_sparse_step == 0): - if not self.use_aclgraph: - # FIXME: custom sparse moe block doesn't work with aclgraph. - self.mlp = CustomSparseMoeBlock(config=config, - quant_config=quant_config, - prefix=f"{prefix}.mlp") - else: - self.mlp = Qwen3MoeSparseMoeBlock(vllm_config=vllm_config, - prefix=f"{prefix}.mlp") - else: - self.mlp = Qwen3MoeMLP(hidden_size=config.hidden_size, - intermediate_size=config.intermediate_size, - hidden_act=config.hidden_act, - quant_config=quant_config, - prefix=f"{prefix}.mlp") - self.input_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_attention_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - - self.enable_sequence_parallelism = ( - vllm_config.compilation_config.pass_config.enable_sp - if vllm_config is not None else False) - - def forward( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - residual: Optional[torch.Tensor], - kv_cache: Optional[torch.Tensor] = None, - attn_metadata: Optional[AttentionMetadata] = None, - _metadata_for_padding: Optional[MetadataForPadding] = None, - ) -> torch.Tensor: - - # To prevent precision issues during the decoder phase when only prefilling enables SP - if not self.enable_sequence_parallelism: - self.self_attn.o_proj.reduce_results = True - else: - self.self_attn.o_proj.reduce_results = not _metadata_for_padding.not_dummy_and_is_prefill if _metadata_for_padding is not None else True - - # Self Attention - if residual is None: - residual = hidden_states - if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill: - residual = _metadata_for_padding.padding_slice(residual) - - hidden_states = self.input_layernorm(hidden_states) - else: - hidden_states, residual = self.input_layernorm( - hidden_states, residual) - - if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill: - hidden_states = _metadata_for_padding.allgather_unpadding_aligned( - hidden_states) - - hidden_states = self.self_attn( - positions=positions, - hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, - ) - - if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill: - hidden_states = _metadata_for_padding.padding_aligned_reduce_scatter( - hidden_states) - - # Fully Connected - hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual) - - if not self.use_aclgraph: - hidden_states = self.mlp( - hidden_states, _metadata_for_padding=_metadata_for_padding) - else: - hidden_states = self.mlp(hidden_states) - - return hidden_states, residual - - -@support_torch_compile -class CustomQwen3MoeModel(Qwen3MoeModel): - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - nn.Module.__init__(self) - config = vllm_config.model_config.hf_config - cache_config = vllm_config.cache_config - quant_config = 
vllm_config.quant_config - - parallel_config = vllm_config.parallel_config - eplb_config = parallel_config.eplb_config - self.num_redundant_experts = eplb_config.num_redundant_experts - self.padding_idx = config.pad_token_id - self.vocab_size = config.vocab_size - self.config = config - self.embed_tokens = VocabParallelEmbedding( - config.vocab_size, - config.hidden_size, - prefix=f"{prefix}.embed_tokens") - self.start_layer, self.end_layer, self.layers = make_layers( - config.num_hidden_layers, - lambda prefix: CustomQwen3MoeDecoderLayer( - config=config, - cache_config=cache_config, - quant_config=quant_config, - vllm_config=vllm_config, - prefix=prefix), - prefix=f"{prefix}.layers", - ) - self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size)) - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - kv_caches: Optional[List[torch.Tensor]] = None, - attn_metadata: Optional[AttentionMetadata] = None, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - _metadata_for_padding: Optional[MetadataForPadding] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: - if get_pp_group().is_first_rank: - if inputs_embeds is not None: - hidden_states = inputs_embeds - else: - hidden_states = self.embed_input_ids(input_ids) - residual = None - else: - assert intermediate_tensors is not None - hidden_states = intermediate_tensors["hidden_states"] - residual = intermediate_tensors["residual"] - for i in range(self.start_layer, self.end_layer): - layer = self.layers[i] - hidden_states, residual = layer( - positions, - hidden_states, - residual, - kv_caches[i - - self.start_layer] if kv_caches is not None else None, - attn_metadata, - _metadata_for_padding=_metadata_for_padding) - if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) - - hidden_states, _ = self.norm(hidden_states, residual) - - if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill: - hidden_states = _metadata_for_padding.allgather_unpadding_aligned( - hidden_states) - - return hidden_states - - -class CustomQwen3MoeForCausalLM(Qwen3MoeForCausalLM): - packed_modules_mapping = { - "qkv_proj": [ - "q_proj", - "k_proj", - "v_proj", - ], - "gate_up_proj": [ - "gate_proj", - "up_proj", - ], - "experts": - ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"], - } - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - nn.Module.__init__(self) - SupportsPP.__init__(self) - SupportsLoRA.__init__(self) - MixtureOfExperts.__init__(self) - config = vllm_config.model_config.hf_config - quant_config = vllm_config.quant_config - self.config = config - self.quant_config = quant_config - self.model = CustomQwen3MoeModel(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) - self.lm_head = ParallelLMHead(config.vocab_size, - config.hidden_size, - quant_config=quant_config, - prefix=maybe_prefix(prefix, "lm_head")) - if self.config.tie_word_embeddings: - self.lm_head.weight = self.model.embed_tokens.weight - self.logits_processor = LogitsProcessor(config.vocab_size) - self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) - - self.enable_sequence_parallelism = vllm_config.compilation_config.pass_config.enable_sp - # Set MoE hyperparameters - 
self.expert_weights: list[torch.Tensor] = [] - - self.moe_layers: list[FusedMoE] = [] - example_layer = None - for layer in self.model.layers: - if isinstance(layer, PPMissingLayer): - continue - - assert isinstance(layer, Qwen3MoeDecoderLayer) - if isinstance(layer.mlp, Qwen3MoeSparseMoeBlock): - example_layer = layer.mlp - self.moe_layers.append(layer.mlp.experts) - - if example_layer is None: - raise RuntimeError("No Qwen3MoE layer found in the model.layers.") - - self.num_moe_layers = len(self.moe_layers) - self.num_expert_groups = 1 - self.num_shared_experts = 0 - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - kv_caches: Optional[List[torch.Tensor]] = None, - attn_metadata: Optional[AttentionMetadata] = None, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: - _metadata_for_padding = init_metadata_for_sp( - input_ids, self.enable_sequence_parallelism) - hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors, - inputs_embeds, _metadata_for_padding) - return hidden_states diff --git a/vllm_ascend/torchair/models/torchair_deepseek_mtp.py b/vllm_ascend/torchair/models/torchair_deepseek_mtp.py deleted file mode 100644 index 4af6f220..00000000 --- a/vllm_ascend/torchair/models/torchair_deepseek_mtp.py +++ /dev/null @@ -1,221 +0,0 @@ -# -# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. -# Adapted from vllm/model_executor/models/deepseek_mtp.py -# Copyright 2023 The vLLM team. -# -# This file is a part of the vllm-ascend project. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from typing import List, Optional - -import torch -import torch.nn as nn -from transformers import PretrainedConfig -from vllm.attention.backends.abstract import AttentionMetadata -from vllm.compilation.decorators import support_torch_compile -from vllm.config import CacheConfig, ModelConfig, VllmConfig -from vllm.distributed import get_tensor_model_parallel_world_size -from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.models.deepseek_mtp import ( - DeepSeekMTP, DeepSeekMultiTokenPredictor, DeepSeekMultiTokenPredictorLayer, - SharedHead) -from vllm.model_executor.models.utils import maybe_prefix -from vllm.sequence import IntermediateTensors - -from vllm_ascend.torchair.models.torchair_deepseek_v2 import \ - TorchairDeepseekV2DecoderLayer - - -class TorchairDeepSeekShareHead(SharedHead): - - def __init__(self, - config: PretrainedConfig, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "") -> None: - nn.Module.__init__(self) - self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.head = ParallelLMHead(config.vocab_size, - config.hidden_size, - quant_config=quant_config, - prefix=maybe_prefix(prefix, "head")) - - -class TorchairDeepSeekMultiTokenPredictorLayer(DeepSeekMultiTokenPredictorLayer - ): - - def __init__( - self, - config: PretrainedConfig, - prefix: str, - model_config: ModelConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - ) -> None: - nn.Module.__init__(self) - - self.tp_size = get_tensor_model_parallel_world_size() - self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.eh_proj = nn.Linear(config.hidden_size * 2, - config.hidden_size, - bias=False) - self.shared_head = TorchairDeepSeekShareHead(config=config, - quant_config=quant_config, - prefix=maybe_prefix( - prefix, - "shared_head")) - self.mtp_block = TorchairDeepseekV2DecoderLayer( - config, prefix, model_config, cache_config, quant_config) - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, - previous_hidden_states: torch.Tensor, - inputs_embeds: Optional[torch.Tensor] = None, - spec_step_index: int = 0, - ) -> torch.Tensor: - assert inputs_embeds is not None - # masking inputs at position 0, as not needed by MTP - inputs_embeds = torch.where((positions == 0).unsqueeze(-1), - torch.zeros_like(inputs_embeds), - inputs_embeds) - inputs_embeds = self.enorm(inputs_embeds) - previous_hidden_states = self.hnorm(previous_hidden_states) - - hidden_states = self.eh_proj( - torch.cat([inputs_embeds, previous_hidden_states], dim=-1)) - - del inputs_embeds, previous_hidden_states - replace_allreduce = hidden_states.shape[0] % self.tp_size == 0 - - hidden_states, residual = self.mtp_block( - positions=positions, - hidden_states=hidden_states, - residual=None, - kv_cache=kv_cache, - attn_metadata=attn_metadata, - replace_allreduce=replace_allreduce) - hidden_states = residual + hidden_states - return hidden_states - - -class TorchairDeepSeekMultiTokenPredictor(DeepSeekMultiTokenPredictor): - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - 
nn.Module.__init__(self) - config = vllm_config.model_config.hf_config - self.mtp_start_layer_idx = config.num_hidden_layers - self.num_mtp_layers = config.num_nextn_predict_layers - # to map the exact layer index from weights - self.layers = torch.nn.ModuleDict({ - str(idx): - TorchairDeepSeekMultiTokenPredictorLayer( - config, - f"{prefix}.layers.{idx}", - model_config=vllm_config.model_config, - cache_config=vllm_config.cache_config, - quant_config=vllm_config.quant_config, - ) - for idx in range(self.mtp_start_layer_idx, - self.mtp_start_layer_idx + self.num_mtp_layers) - }) - self.embed_tokens = VocabParallelEmbedding( - config.vocab_size, - config.hidden_size, - ) - - # Note: torch._dynamo.exc.Unsupported: builtin: str - self.layers_list = [ - self.layers[str(idx)] - for idx in range(self.mtp_start_layer_idx, - self.mtp_start_layer_idx + self.num_mtp_layers) - ] - self.logits_processor = LogitsProcessor(config.vocab_size) - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - kv_caches: torch.Tensor, - attn_metadata: AttentionMetadata, - previous_hidden_states: torch.Tensor, - inputs_embeds: Optional[torch.Tensor] = None, - spec_step_idx: int = 0, - ) -> torch.Tensor: - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - current_step_idx = (spec_step_idx % self.num_mtp_layers) - step_kv_cache = kv_caches[ - current_step_idx] if kv_caches is not None else None - return self.layers_list[current_step_idx]( - input_ids, - positions, - step_kv_cache, - attn_metadata, - previous_hidden_states, - inputs_embeds, - current_step_idx, - ) - - def compute_logits( - self, - hidden_states: torch.Tensor, - spec_step_idx: int = 0, - ) -> torch.Tensor: - current_step_idx = (spec_step_idx % self.num_mtp_layers) - mtp_layer = self.layers_list[current_step_idx] - logits = self.logits_processor(mtp_layer.shared_head.head, - mtp_layer.shared_head(hidden_states)) - return logits - - -@support_torch_compile -class TorchairDeepSeekMTP(DeepSeekMTP): - # NOTE 1.The quantized MTP layer of deepseek on the NPU is not quantized; - # NOTE 2.The description file generated by the current msmodelslim tool does not have - # MTP layer info. Please manually add it and set the value to FLOAT. 
- packed_modules_mapping = { - "gate_up_proj": ["gate_proj", "up_proj"], - "experts": - ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"] - } - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - nn.Module.__init__(self) - self.config = vllm_config.model_config.hf_config - self.model = TorchairDeepSeekMultiTokenPredictor( - vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")) - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - kv_caches: Optional[List[torch.Tensor]] = None, - attn_metadata: Optional[AttentionMetadata] = None, - hidden_states: Optional[torch.Tensor] = None, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - spec_step_idx: int = 0, - ) -> torch.Tensor: - hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, hidden_states, inputs_embeds, - spec_step_idx) - return hidden_states diff --git a/vllm_ascend/torchair/models/torchair_deepseek_v2.py b/vllm_ascend/torchair/models/torchair_deepseek_v2.py deleted file mode 100644 index c29c440b..00000000 --- a/vllm_ascend/torchair/models/torchair_deepseek_v2.py +++ /dev/null @@ -1,1339 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. -# Copyright 2023 The vLLM team. -# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved. -# -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# # Adapted from -# # vllm-project/vllm/blob/main/vllm/model_executor/models/deepseek_v2.py -# # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py -# # vllm-project/vllm/vllm/model_executor/models/deepseek_v2.py -# """Inference-only DeepseekV2/DeepseekV3 model.""" - -from typing import Callable, Iterable, List, Optional, Tuple, Union - -import torch -import torch_npu -from torch import nn -from transformers import PretrainedConfig -from vllm.attention.backends.abstract import AttentionMetadata -from vllm.attention.layer import MLAAttention -from vllm.config import CacheConfig, ModelConfig, VllmConfig -from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, - get_tp_group, split_tensor_along_last_dim, - tensor_model_parallel_all_gather, - tensor_model_parallel_all_reduce, - tensor_model_parallel_reduce_scatter) -from vllm.distributed.parallel_state import get_dp_group, get_ep_group -from vllm.forward_context import get_forward_context -from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - MergedColumnParallelLinear, - ReplicatedLinear, - RowParallelLinear, - UnquantizedLinearMethod) -from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, maybe_remap_kv_scale_name) -from vllm.model_executor.models.deepseek_v2 import \ - DeepseekV2ForCausalLM # noqa: E501 -from vllm.model_executor.models.deepseek_v2 import \ - yarn_get_mscale # noqa: E501 -from vllm.model_executor.models.deepseek_v2 import ( - DeepseekV2Attention, DeepseekV2DecoderLayer, DeepseekV2MLAAttention, - get_spec_layer_idx_from_weight_name) -from vllm.model_executor.models.utils import ( - PPMissingLayer, is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) -from vllm.sequence import IntermediateTensors - -from vllm_ascend import envs -from vllm_ascend.ascend_config import get_ascend_config -from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch -from vllm_ascend.quantization.quant_config import AscendLinearMethod -from vllm_ascend.torchair.ops.torchair_fused_moe import TorchairAscendFusedMoE -from vllm_ascend.torchair.quantization.torchair_w8a8_dynamic import \ - TorchairAscendW8A8DynamicLinearMethod -from vllm_ascend.utils import dispose_tensor, oproj_tp_enable - - -class Indexer(nn.Module): - - def __init__(self, - config, - dim: int = 7168, - n_heads: int = 64, - head_dim: int = 128, - index_topk: int = 2048, - q_lora_rank: int = 1536, - rope_head_dim: int = 64, - quant_config: Optional[QuantizationConfig] = None, - prefix: Optional[str] = ""): - super().__init__() - - self.dim: int = dim # 7168 - self.n_heads: int = n_heads # 64 - self.head_dim: int = head_dim # 128 - self.rope_head_dim: int = rope_head_dim # 64 - self.index_topk: int = index_topk # 2048 - self.q_lora_rank: int = q_lora_rank # 1536 - self.wq_b = ReplicatedLinear( - self.q_lora_rank, - self.n_heads * self.head_dim, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.wq_b", - return_bias=False, - ) - self.wk = 
ReplicatedLinear( - self.dim, - self.head_dim, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.wk", - return_bias=False, - ) - self.weights_proj = ReplicatedLinear( - self.dim, - self.n_heads, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.weights_proj", - return_bias=False, - ) - self.k_norm = nn.LayerNorm(self.head_dim) - self.softmax_scale = self.head_dim**-0.5 - - def forward(self): - return - - -class TorchairDeepseekV2SiluAndMul(SiluAndMul): - - def __init__(self, - *, - weight_scale: Optional[Callable[[], torch.Tensor]] = None): - super().__init__() - self.weight_scale = weight_scale - - def forward_oot(self, x: Union[torch.Tensor, Tuple[torch.Tensor, - torch.Tensor]]): - if isinstance(x, tuple): - assert self.weight_scale is not None - # For AscendW8A8DynamicLinearMethod: - # a dynamic scale is passed along with the quantized value. - quantized_x, dynamic_scale = x - return torch_npu.npu_dequant_swiglu_quant( - x=quantized_x, - weight_scale=self.weight_scale(), - activation_scale=dynamic_scale, - activate_left=True, - quant_mode=1) - else: - return super().forward_oot(x) - - -class TorchairDeepseekV2MergedReplicatedLinear(ReplicatedLinear): - - def __init__( - self, - input_size: int, - output_sizes: list[int], - bias: bool = True, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ): - self.output_sizes = output_sizes - super().__init__(input_size, - sum(output_sizes), - bias=bias, - quant_config=quant_config, - prefix=prefix) - - def weight_loader(self, param: torch.nn.Parameter, - loaded_weight: torch.Tensor, loaded_shard_id: int): - # With no support for GGUF format yet. - assert not getattr(param, "is_gguf_weight", False) - assert not getattr(param, "is_gguf_weight_type", False) - - assert loaded_shard_id < len(self.output_sizes) - shard_offset = sum(self.output_sizes[:loaded_shard_id]) - shard_size = self.output_sizes[loaded_shard_id] - shard = param.data.narrow(param.output_dim, shard_offset, shard_size) - - assert shard.size() == loaded_weight.size(), ( - f"Tried to load weights of size {loaded_weight.size()}" - f"to a parameter shard of id {loaded_shard_id} size {shard.size()}" - ) - shard.copy_(loaded_weight) - - -class TorchairDeepseekV2RowParallelLinearReplaceAllreduce(RowParallelLinear): - - def forward( - self, - input_, - is_prefill=True, - is_force_scatter=False - ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[nn.Parameter]]]: - if self.input_is_parallel: - input_parallel = input_ - else: - tp_rank = get_tensor_model_parallel_rank() - splitted_input = split_tensor_along_last_dim( - input_, num_partitions=self.tp_size) - input_parallel = splitted_input[tp_rank].contiguous() - - # Matrix multiply. 
- assert self.quant_method is not None - # Only fuse bias add into GEMM for rank 0 (this ensures that - # bias will not get added more than once in TP>1 case) - bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias - output_parallel = self.quant_method.apply(self, - input_parallel, - bias=bias_) - forward_context = get_forward_context() - if self.reduce_results and self.tp_size > 1: - num_tokens = output_parallel.shape[0] - if is_force_scatter and num_tokens % self.tp_size: - output_parallel = nn.functional.pad( - output_parallel, (0, 0, 0, -num_tokens % self.tp_size)) - if is_force_scatter or (not forward_context.with_prefill - and output_parallel.shape[0] % self.tp_size - == 0): - output = tensor_model_parallel_reduce_scatter(output_parallel, - dim=0) - else: - output = tensor_model_parallel_all_reduce(output_parallel) - else: - output = output_parallel - - output_bias = self.bias if self.skip_bias_add else None - - if not self.return_bias: - return output - return output, output_bias - - -class TorchairDeepseekV2RowParallelLinear(RowParallelLinear): - - def forward( - self, - input_, - is_prefill=True, - is_force_scatter=False - ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[nn.Parameter]]]: - if self.input_is_parallel: - input_parallel = input_ - else: - tp_rank = get_tensor_model_parallel_rank() - splitted_input = split_tensor_along_last_dim( - input_, num_partitions=self.tp_size) - input_parallel = splitted_input[tp_rank].contiguous() - - # Matrix multiply. - assert self.quant_method is not None - # Only fuse bias add into GEMM for rank 0 (this ensures that - # bias will not get added more than once in TP>1 case) - bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias - output_parallel = self.quant_method.apply(self, - input_parallel, - bias=bias_) - if self.reduce_results and self.tp_size > 1: - output = tensor_model_parallel_all_reduce(output_parallel) - else: - output = output_parallel - - output_bias = self.bias if self.skip_bias_add else None - - if not self.return_bias: - return output - return output, output_bias - - -class TorchairDeepseekV2MLP(nn.Module): - - def __init__( - self, - hidden_size: int, - intermediate_size: int, - hidden_act: str, - quant_config: Optional[QuantizationConfig] = None, - reduce_results: bool = True, - force_replicate: bool = False, - prefix: str = "", - ) -> None: - super().__init__() - if not force_replicate: - self.gate_up_proj = MergedColumnParallelLinear( - hidden_size, [intermediate_size] * 2, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.gate_up_proj") - self.down_proj = RowParallelLinear(intermediate_size, - hidden_size, - bias=False, - quant_config=quant_config, - reduce_results=reduce_results, - prefix=f"{prefix}.down_proj") - else: - self.gate_up_proj = TorchairDeepseekV2MergedReplicatedLinear( - hidden_size, [intermediate_size] * 2, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.gate_up_proj") - self.down_proj = ReplicatedLinear(intermediate_size, - hidden_size, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.down_proj") - if hidden_act != "silu": - raise ValueError(f"Unsupported activation: {hidden_act}. 
" - "Only silu is supported for now.") - - quant_method = self.gate_up_proj.quant_method - if isinstance(quant_method, UnquantizedLinearMethod): - self.act_fn = TorchairDeepseekV2SiluAndMul() - elif (isinstance(quant_method, AscendLinearMethod) - and isinstance(quant_method.quant_method, - TorchairAscendW8A8DynamicLinearMethod)): - # TODO(sdmyzlp): Currently preserved as before: - # 1. The only quantization supported for silu is W8A8Dynamic - # 2. Output dtype of gate_up/down is fixed to be int32/bfloat16 - # - # Maybe one can implement a better and more general configuration - # scheme, e.g. by somehow passing around the tweaked `quant_config` - self.act_fn = TorchairDeepseekV2SiluAndMul( - # Use lazy binding, for `weight_scale_fp32` is accessible - # only after `process_weights_after_loading`. - weight_scale=lambda: self.gate_up_proj.weight_scale_fp32) - # To be consumed by AscendW8A8DynamicLinearMethod.apply() - self.gate_up_proj._ascend_quant_config = { - "output_dtype": torch.int32, - "pertoken_scale": False, - "return_scale": True, - } - self.down_proj._ascend_quant_config = { - "output_dtype": torch.bfloat16, - "pertoken_scale": True, - "return_scale": False, - } - else: - raise NotImplementedError( - f"Quantization with [{type(quant_method)}] is NOT supported") - - def forward(self, x): - gate_up, _ = self.gate_up_proj(x) - x = self.act_fn(gate_up) - x, _ = self.down_proj(x) - return x - - -class TorchairDeepseekV2MoE(nn.Module): - - top_k: int - - def __init__( - self, - config: PretrainedConfig, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ): - super().__init__() - self.tp_size = get_tensor_model_parallel_world_size() - self.routed_scaling_factor = config.routed_scaling_factor - self.n_shared_experts = config.n_shared_experts - if self.tp_size > config.n_routed_experts: - raise ValueError( - f"Tensor parallel size {self.tp_size} is greater than " - f"the number of experts {config.n_routed_experts}.") - - if config.hidden_act != "silu": - raise ValueError(f"Unsupported activation: {config.hidden_act}. " - "Only silu is supported for now.") - - ascend_config = get_ascend_config() - self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled - self.multistream_overlap_shared_expert = \ - ascend_config.multistream_overlap_shared_expert and \ - self.torchair_graph_enabled - - self.enable_super_kernel = ascend_config.torchair_graph_config.enable_super_kernel - self.params_dtype = torch.float32 if self.enable_super_kernel else \ - torch.get_default_dtype() - # Converting gate weight to fp32 is to adapt to the super kernel feature. - # Super kernel feature currently cannot fuse operators such as cast, stridedslice, and add. - # In the moe stage, Cast will interrupt the fusion of the super kernel. To avoid this problem, - # modifications will be made in the initialization stage. 
- self.gate = ReplicatedLinear(config.hidden_size, - config.n_routed_experts, - bias=False, - quant_config=None, - params_dtype=self.params_dtype, - prefix=f"{prefix}.gate") - if config.topk_method == "noaux_tc": - self.gate.e_score_correction_bias = nn.Parameter( - torch.empty(config.n_routed_experts, dtype=self.params_dtype)) - else: - self.gate.e_score_correction_bias = None - - self.experts = TorchairAscendFusedMoE( - num_experts=config.n_routed_experts, - top_k=config.num_experts_per_tok, - hidden_size=config.hidden_size, - intermediate_size=config.moe_intermediate_size, - reduce_results=False, - renormalize=config.norm_topk_prob, - quant_config=quant_config, - use_grouped_topk=True, - num_expert_group=config.n_group, - topk_group=config.topk_group, - prefix=f"{prefix}.experts", - scoring_func=config.scoring_func, - e_score_correction_bias=self.gate.e_score_correction_bias) - - if config.n_shared_experts is not None: - self.all_reduce_merge = self.experts.all_reduce_merge - reduce_results = not self.all_reduce_merge - intermediate_size = (config.moe_intermediate_size * - config.n_shared_experts) - enable_shared_expert_dp = ascend_config.enable_shared_expert_dp - self.shared_experts = TorchairDeepseekV2MLP( - hidden_size=config.hidden_size, - intermediate_size=intermediate_size, - hidden_act=config.hidden_act, - quant_config=quant_config, - reduce_results=reduce_results, - force_replicate=self.multistream_overlap_shared_expert - or enable_shared_expert_dp, - prefix=f"{prefix}.shared_experts", - ) - else: - self.shared_experts = None # type: ignore - TorchairDeepseekV2MoE.top_k = config.num_experts_per_tok - - self.dp_size = get_dp_group().world_size - - self.tp_group = get_tp_group().device_group - self.tp_rank = get_tp_group().rank_in_group - self.ep_group = get_ep_group() - - self.params_dtype = torch.get_default_dtype() - self.rm_router_logits = self.experts.rm_router_logits - - def forward(self, - hidden_states: torch.Tensor, - attn_metadata: Optional[AttentionMetadata] = None, - replace_allreduce: bool = False) -> torch.Tensor: - - forward_context = get_forward_context() - # when profile runs, force experts to load balanced tokens - # to avoid high memory consumption on a single rank. 
- - enable_force_load_balance = forward_context.in_profile_run - - is_prefill = forward_context.with_prefill - - # router_logits: (num_tokens, n_experts) - router_logits = None - if not self.rm_router_logits and not self.multistream_overlap_shared_expert: - router_logits, _ = self.gate(hidden_states) - - experts_hidden_states = self.experts( - hidden_states=hidden_states, - router_logits=router_logits, - is_prefill=is_prefill, - top_k=TorchairDeepseekV2MoE.top_k, - enable_force_load_balance=enable_force_load_balance, - shared_experts=self.shared_experts, - gate=self.gate, - replace_allreduce=replace_allreduce) - - hidden_states = ( - experts_hidden_states[0] * self.routed_scaling_factor + - experts_hidden_states[1]) - if self.all_reduce_merge: - # When all_reduce_merge is in progress, shared_experts does not do all_reduce in mlp, but waits until shared_experts+router_experts are completed before doing all_reduce - hidden_states = tensor_model_parallel_all_reduce(hidden_states) - - return hidden_states - - -class TorchairDeepseekV2MLAAttention(DeepseekV2MLAAttention): - - def __init__( - self, - config: PretrainedConfig, - hidden_size: int, - num_heads: int, - qk_nope_head_dim: int, - qk_rope_head_dim: int, - v_head_dim: int, - q_lora_rank: Optional[int], - kv_lora_rank: int, - max_position_embeddings: int = 8192, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - decoder_layer=None, - ) -> None: - nn.Module.__init__(self) - self.hidden_size = hidden_size - self.qk_nope_head_dim = qk_nope_head_dim - self.qk_rope_head_dim = qk_rope_head_dim - self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim - self.v_head_dim = v_head_dim - - self.q_lora_rank = q_lora_rank - self.kv_lora_rank = kv_lora_rank - - self.num_heads = num_heads - self.tp_size = get_tensor_model_parallel_world_size() - assert num_heads % self.tp_size == 0 - self.num_local_heads = num_heads // self.tp_size - self.layers = config.num_hidden_layers - self.first_k_dense_replace = config.first_k_dense_replace - - self.scaling = self.qk_head_dim**-0.5 - self.max_position_embeddings = max_position_embeddings - - self.prefix = prefix - self.debug_layer_idx = int(self.prefix.split(".")[-2]) - - ascend_config = get_ascend_config() - self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled - self.enable_multistream_mla = \ - ascend_config.torchair_graph_config.enable_multistream_mla - self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp - - if self.q_lora_rank is not None: - self.q_a_proj = ReplicatedLinear(self.hidden_size, - self.q_lora_rank, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.q_a_proj") - self.q_a_layernorm = RMSNorm(self.q_lora_rank, - eps=config.rms_norm_eps) - self.q_b_proj = ColumnParallelLinear(q_lora_rank, - self.num_heads * - self.qk_head_dim, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.q_b_proj") - else: - self.q_proj = ColumnParallelLinear(self.hidden_size, - self.num_heads * - self.qk_head_dim, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.q_proj") - - self.kv_a_proj_with_mqa = ReplicatedLinear( - self.hidden_size, - self.kv_lora_rank + self.qk_rope_head_dim, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.kv_a_proj_with_mqa") - self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, - eps=config.rms_norm_eps) - self.kv_b_proj = ColumnParallelLinear( - self.kv_lora_rank, - self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), - bias=False, - 
quant_config=quant_config, - prefix=f"{prefix}.kv_b_proj") - - if oproj_tp_enable(): - self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim, - self.hidden_size, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.o_proj") - elif (config.n_routed_experts is not None - and self.debug_layer_idx >= config.first_k_dense_replace - and self.debug_layer_idx % config.moe_layer_freq == 0 - and (ascend_config.multistream_overlap_shared_expert - or self.enable_shared_expert_dp)): - self.o_proj = TorchairDeepseekV2RowParallelLinearReplaceAllreduce( - self.num_heads * self.v_head_dim, - self.hidden_size, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.o_proj") - else: - self.o_proj = TorchairDeepseekV2RowParallelLinear( - self.num_heads * self.v_head_dim, - self.hidden_size, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.o_proj") - - if config.rope_parameters["rope_type"] != "default": - config.rope_parameters["rope_type"] = "deepseek_yarn" - self.rotary_emb = get_rope(qk_rope_head_dim, - rotary_dim=qk_rope_head_dim, - max_position=max_position_embeddings, - rope_parameters=config.rope_parameters, - is_neox_style=False) - if config.rope_parameters["rope_type"] != "default": - mscale_all_dim = config.rope_parameters.get( - "mscale_all_dim", False) - scaling_factor = config.rope_parameters["factor"] - mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim)) - self.scaling = self.scaling * mscale * mscale - - # In the MLA backend, kv_cache includes both k_c and - # pe (i.e. decoupled position embeddings). In particular, - # the concat_and_cache_mla op requires - # k_c.size(1) + k_pe.size(1) == kv_cache.size(2) - # i.e. - # kv_lora_rank + qk_rope_head_dim == head_size - self.mla_attn = MLAAttention( - num_heads=self.num_local_heads, - scale=self.scaling, - qk_nope_head_dim=self.qk_nope_head_dim, - qk_rope_head_dim=self.qk_rope_head_dim, - v_head_dim=self.v_head_dim, - q_lora_rank=self.q_lora_rank, - kv_lora_rank=self.kv_lora_rank, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn", - use_sparse=False, - indexer=None, - # MLA Args - rotary_emb=self.rotary_emb, - q_a_proj=self.q_a_proj if self.q_lora_rank is not None else None, - q_a_layernorm=self.q_a_layernorm - if self.q_lora_rank is not None else None, - q_proj=self.q_proj if self.q_lora_rank is None else self.q_b_proj, - q_b_proj=self.q_b_proj if self.q_lora_rank is not None else None, - kv_a_proj_with_mqa=self.kv_a_proj_with_mqa, - kv_a_layernorm=self.kv_a_layernorm, - kv_b_proj=self.kv_b_proj, - o_proj=self.o_proj, - ) - - def forward( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - kv_cache: Optional[torch.Tensor] = None, - attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor: - forward_context = get_forward_context() - enable_multistream_mla = (self.enable_multistream_mla - and attn_metadata is not None - and not forward_context.with_prefill - and attn_metadata.num_decodes > 0) - forward_kwargs = {"enable_multistream_mla": enable_multistream_mla} - if self.q_lora_rank is not None: - maybe_npu_prefetch(self.q_a_proj.weight, - hidden_states, - enabled=enable_multistream_mla) - ckq = self.q_a_proj(hidden_states)[0] - hidden_states_or_q_c = self.q_a_layernorm(ckq) - forward_kwargs['ckq'] = ckq - else: - hidden_states_or_q_c = hidden_states - if self.torchair_graph_enabled: - output_shape = hidden_states.shape - output = torch.empty(output_shape, - dtype=hidden_states_or_q_c.dtype, - device=hidden_states_or_q_c.device) - 
forward_kwargs['output'] = output - output = self.mla_attn.impl.forward(self.mla_attn, - hidden_states_or_q_c, - hidden_states, None, kv_cache, - attn_metadata, - **forward_kwargs) - output = output.view(-1, output_shape[-1]) - return output - else: - kv_no_split = self.kv_a_proj_with_mqa(hidden_states)[0] - if self.enable_shared_expert_dp and self.debug_layer_idx > self.first_k_dense_replace and self.debug_layer_idx < self.layers: - hidden_states_or_q_c = get_tp_group().all_gather( - hidden_states_or_q_c, 0) - kv_no_split = get_tp_group().all_gather(kv_no_split, 0) - - kv_c, k_pe = kv_no_split.split( - [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) - kv_c_normed = self.kv_a_layernorm(kv_c.contiguous()) - if not self.enable_shared_expert_dp or self.debug_layer_idx < self.first_k_dense_replace: - output_shape = hidden_states.shape - else: - num_tokens = hidden_states_or_q_c.shape[0] - rows = num_tokens // self.tp_size - if num_tokens % self.tp_size: - rows += 1 - output_shape = (rows, hidden_states.shape[1]) - return self.mla_attn(hidden_states_or_q_c, - kv_c_normed, - k_pe, - output_shape=output_shape) - - -class TorchairDeepseekV2SFAAttention(DeepseekV2MLAAttention): - - def __init__( - self, - config: PretrainedConfig, - hidden_size: int, - num_heads: int, - qk_nope_head_dim: int, - qk_rope_head_dim: int, - v_head_dim: int, - q_lora_rank: Optional[int], - kv_lora_rank: int, - max_position_embeddings: int = 8192, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - decoder_layer=None, - ) -> None: - nn.Module.__init__(self) - self.hidden_size = hidden_size - self.qk_nope_head_dim = qk_nope_head_dim - self.qk_rope_head_dim = qk_rope_head_dim - self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim - self.v_head_dim = v_head_dim - - self.q_lora_rank = q_lora_rank - self.kv_lora_rank = kv_lora_rank - - self.num_heads = num_heads - self.tp_size = get_tensor_model_parallel_world_size() - assert num_heads % self.tp_size == 0 - self.num_local_heads = num_heads // self.tp_size - self.layers = config.num_hidden_layers - self.first_k_dense_replace = config.first_k_dense_replace - - self.scaling = self.qk_head_dim**-0.5 - self.max_position_embeddings = max_position_embeddings - - self.prefix = prefix - self.debug_layer_idx = int(self.prefix.split(".")[-2]) - - ascend_config = get_ascend_config() - self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp - self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled - - if self.q_lora_rank is not None: - self.q_a_proj = ReplicatedLinear( - self.hidden_size, - self.q_lora_rank, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.q_a_proj", - return_bias=False, - ) - self.q_a_layernorm = RMSNorm(self.q_lora_rank, - eps=config.rms_norm_eps) - self.q_b_proj = ColumnParallelLinear( - q_lora_rank, - self.num_heads * self.qk_head_dim, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.q_b_proj", - return_bias=False, - ) - else: - self.q_proj = ColumnParallelLinear( - self.hidden_size, - self.num_heads * self.qk_head_dim, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.q_proj", - return_bias=False, - ) - - self.kv_a_proj_with_mqa = ReplicatedLinear( - self.hidden_size, - self.kv_lora_rank + self.qk_rope_head_dim, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.kv_a_proj_with_mqa", - return_bias=False, - ) - self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, - eps=config.rms_norm_eps) - self.kv_b_proj = 
ColumnParallelLinear( - self.kv_lora_rank, - self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.kv_b_proj", - return_bias=False, - ) - if (config.n_routed_experts is not None - and self.debug_layer_idx >= config.first_k_dense_replace - and self.debug_layer_idx % config.moe_layer_freq == 0 - and (ascend_config.multistream_overlap_shared_expert - or self.enable_shared_expert_dp)): - self.o_proj = TorchairDeepseekV2RowParallelLinearReplaceAllreduce( - self.num_heads * self.v_head_dim, - self.hidden_size, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.o_proj", - return_bias=False, - ) - else: - self.o_proj = TorchairDeepseekV2RowParallelLinear( - self.num_heads * self.v_head_dim, - self.hidden_size, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.o_proj", - return_bias=False, - ) - - if config.rope_parameters["rope_type"] != "default": - config.rope_parameters["rope_type"] = "deepseek_yarn" - self.rotary_emb = get_rope( - qk_rope_head_dim, - rotary_dim=qk_rope_head_dim, - max_position=max_position_embeddings, - rope_parameters=config.rope_parameters, - is_neox_style=False, - ) - if config.rope_parameters["rope_type"] != "default": - mscale_all_dim = config.rope_parameters.get( - "mscale_all_dim", False) - scaling_factor = config.rope_parameters["factor"] - mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim)) - self.scaling = self.scaling * mscale * mscale - - self.dim: int = config.hidden_size # 7168 - # TODO(zzzzwwjj): wait transformers add these params - self.n_heads: int = 64 # 64 - self.head_dim: int = 128 # 128 - self.index_topk: int = 2048 # 2048 - self.indexer = Indexer( - config, - quant_config=quant_config, - dim=self.dim, - n_heads=self.n_heads, - head_dim=self.head_dim, - index_topk=self.index_topk, - prefix=f"{prefix}.indexer", - ) - self.sfa_attn = MLAAttention( - num_heads=self.num_local_heads, - scale=self.scaling, - qk_nope_head_dim=self.qk_nope_head_dim, - qk_rope_head_dim=self.qk_rope_head_dim, - v_head_dim=self.v_head_dim, - q_lora_rank=self.q_lora_rank, - kv_lora_rank=self.kv_lora_rank, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn", - use_sparse=True, - indexer=self.indexer, - # MLA Args - rotary_emb=self.rotary_emb, - q_a_proj=self.q_a_proj if self.q_lora_rank is not None else None, - q_a_layernorm=self.q_a_layernorm - if self.q_lora_rank is not None else None, - q_proj=self.q_proj if self.q_lora_rank is None else self.q_b_proj, - kv_a_proj_with_mqa=self.kv_a_proj_with_mqa, - kv_a_layernorm=self.kv_a_layernorm, - kv_b_proj=self.kv_b_proj, - o_proj=self.o_proj, - ) - - def forward( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - kv_cache: Optional[torch.Tensor] = None, - attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor: - forward_context = get_forward_context() - if not self.torchair_graph_enabled: - if forward_context.attn_metadata is not None and isinstance( - forward_context.attn_metadata, dict): - attn_metadata = next( - iter(forward_context.attn_metadata.values()), None) - else: - attn_metadata = forward_context.attn_metadata - if kv_cache is None: - kv_cache = self.sfa_attn.kv_cache[ - forward_context.virtual_engine] - - num_tokens = hidden_states.shape[0] - need_gather_q_kv = False - # if self.enable_shared_expert_dp and self.debug_layer_idx > self.first_k_dense_replace and self.debug_layer_idx < self.layers: - # # Simulate all gather to calculate output shape - # num_tokens = 
num_tokens * self.tp_size - # need_gather_q_kv = True - if not self.enable_shared_expert_dp or self.debug_layer_idx != self.first_k_dense_replace: - output_shape = hidden_states.shape - if self.enable_shared_expert_dp and ( - self.debug_layer_idx == self.first_k_dense_replace - or self.debug_layer_idx == self.layers): - rows = num_tokens // self.tp_size - if num_tokens % self.tp_size: - rows += 1 - output_shape = (rows, hidden_states.shape[1]) - output = torch.empty(output_shape, - dtype=hidden_states.dtype, - device=hidden_states.device) - self.sfa_attn.impl.forward(hidden_states, kv_cache, attn_metadata, - need_gather_q_kv, output) - output = output.view(-1, output_shape[-1]) - return output - - -class TorchairDeepseekV2DecoderLayer(DeepseekV2DecoderLayer): - - def __init__( - self, - config: PretrainedConfig, - prefix: str, - model_config: ModelConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - ) -> None: - nn.Module.__init__(self) - self.hidden_size = config.hidden_size - max_position_embeddings = getattr(config, "max_position_embeddings", - 8192) - # DecoderLayers are created with `make_layers` which passes the prefix - # with the layer's index. - layer_idx = int(prefix.split(sep='.')[-1]) - self.layer_idx = layer_idx - self.layers = config.num_hidden_layers - self.tp_size = get_tensor_model_parallel_world_size() - self.tp_rank = get_tp_group().rank_in_group - ascend_config = get_ascend_config() - self.use_mla = False - self.use_sparse = False - # TODO: enable mla in vllm-ascend - if model_config.use_mla: - if hasattr(model_config.hf_config, "index_topk"): - attn_cls = TorchairDeepseekV2SFAAttention - self.use_sparse = True - else: - attn_cls = TorchairDeepseekV2MLAAttention # type: ignore[assignment] - self.use_mla = True - else: - attn_cls = DeepseekV2Attention - self.self_attn = attn_cls( - config=config, - hidden_size=self.hidden_size, - num_heads=config.num_attention_heads, - qk_nope_head_dim=config.qk_nope_head_dim, - qk_rope_head_dim=config.qk_rope_head_dim, - v_head_dim=config.v_head_dim, - q_lora_rank=config.q_lora_rank - if hasattr(config, "q_lora_rank") else None, - kv_lora_rank=config.kv_lora_rank, - max_position_embeddings=max_position_embeddings, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.self_attn", - decoder_layer=self, - ) - - if (config.n_routed_experts is not None - and layer_idx >= config.first_k_dense_replace - and layer_idx % config.moe_layer_freq == 0): - self.mlp = TorchairDeepseekV2MoE( - config=config, - quant_config=quant_config, - prefix=f"{prefix}.mlp", - ) - self.mla_moe_communication = ascend_config.multistream_overlap_shared_expert \ - and model_config.use_mla and self.tp_size > 1 - else: - self.mlp = TorchairDeepseekV2MLP( - hidden_size=config.hidden_size, - intermediate_size=config.intermediate_size, - hidden_act=config.hidden_act, - quant_config=quant_config, - prefix=f"{prefix}.mlp", - ) - self.mla_moe_communication = False - self.input_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_attention_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.routed_scaling_factor = config.routed_scaling_factor - self.first_k_dense_replace = config.first_k_dense_replace - self.tp_group = get_tp_group().device_group - self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp - - def forward( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - residual: Optional[torch.Tensor], - kv_cache: 
Optional[torch.Tensor] = None, - attn_metadata: Optional[AttentionMetadata] = None, - replace_allreduce: bool = False, - ) -> torch.Tensor: - # Self Attention - forward_context = get_forward_context() - if attn_metadata is not None: - decoding_condition_met = ( - not attn_metadata.is_prefill if self.use_sparse else - not forward_context.with_prefill if self.use_mla else False) - mla_moe_communication = decoding_condition_met and self.mla_moe_communication and replace_allreduce - else: - mla_moe_communication = False - - if (envs.VLLM_ASCEND_ENABLE_MLAPO - and isinstance(self.self_attn, TorchairDeepseekV2SFAAttention) - and attn_metadata is not None - and not forward_context.with_prefill): - if residual is not None: - hidden_states = hidden_states + residual - residual = hidden_states - else: - if residual is None: - residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) - else: - previous_hidden_states, previous_residual = hidden_states, residual - hidden_states, residual = self.input_layernorm( - hidden_states, residual) - # Dispose hidden_states and residual from the previous layer - # to save npu memory because they're no longer used. - dispose_tensor(previous_hidden_states) - dispose_tensor(previous_residual) - if mla_moe_communication and self.layer_idx > self.first_k_dense_replace and self.layer_idx < self.layers: - hidden_states = tensor_model_parallel_all_gather(hidden_states, - dim=0) - - hidden_states = self.self_attn( - positions=positions, - hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, - ) - - if mla_moe_communication and residual.shape[0] != hidden_states.shape[ - 0]: - chunk_hidden_states = torch.tensor_split(residual, - self.tp_size, - dim=0) - residual = chunk_hidden_states[self.tp_rank] - - if hidden_states.dtype == torch.float16: - # Fix FP16 overflow - # We scale both hidden_states and residual before - # rmsnorm, and rmsnorm result would not affect by scale. - hidden_states *= 1. / self.routed_scaling_factor - if self.layer_idx == 0: - # The residual is shared by all layers, we only scale it on - # first layer. - residual *= 1. / self.routed_scaling_factor - - tp_size = get_tensor_model_parallel_world_size() - if self.enable_shared_expert_dp and ( - self.layer_idx == self.first_k_dense_replace - or self.layer_idx == self.layers) and tp_size > 1: - num_tokens, _ = residual.shape - if num_tokens % tp_size: - residual = nn.functional.pad(residual, - (0, 0, 0, -num_tokens % tp_size)) - chunk_residual = torch.tensor_split(residual, tp_size, dim=0) - tp_rank = get_tensor_model_parallel_rank() - residual = chunk_residual[tp_rank] - - # Fully Connected - hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual) - - if isinstance(self.mlp, TorchairDeepseekV2MoE): - hidden_states = self.mlp(hidden_states, - attn_metadata, - replace_allreduce=mla_moe_communication) - else: - hidden_states = self.mlp(hidden_states) - - if isinstance(self.mlp, TorchairDeepseekV2MLP - ) and hidden_states.dtype == torch.float16: - # Fix FP16 overflow - # Scaling the DeepseekV2MLP output, it is the input of - # input_layernorm of next decoder layer. - # The scaling of DeepseekV2MOE output would be done in the forward - # of DeepseekV2MOE - hidden_states *= 1. 
/ self.routed_scaling_factor - if mla_moe_communication and self.layer_idx >= self.layers - 1: - hidden_states = tensor_model_parallel_all_gather(hidden_states, - dim=0) - residual = tensor_model_parallel_all_gather(residual, dim=0) - - # for last layer of main model and mtp layer. - if self.enable_shared_expert_dp and self.layer_idx >= ( - self.layers - 1) and tp_size > 1: - hidden_states = get_tp_group().all_gather(hidden_states, 0) - residual = get_tp_group().all_gather(residual, 0) - - attn_metadata = get_forward_context().attn_metadata - if attn_metadata is not None and isinstance(attn_metadata, dict): - attn_metadata = next(iter(attn_metadata.values()), None) - if attn_metadata is not None: - num_tokens = attn_metadata.num_actual_tokens - else: - num_tokens = hidden_states.shape[0] - - if num_tokens < hidden_states.shape[0]: - hidden_states = hidden_states[:num_tokens] - residual = residual[:num_tokens] - - return hidden_states, residual - - -class TorchairDeepseekV2Model(nn.Module): - - fall_back_to_pt_during_load = False - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__() - - config = vllm_config.model_config.hf_config - model_config = vllm_config.model_config - cache_config = vllm_config.cache_config - quant_config = vllm_config.quant_config - - self.padding_idx = config.pad_token_id - self.vocab_size = config.vocab_size - self.tp_size = get_tensor_model_parallel_world_size() - - if get_pp_group().is_first_rank: - self.embed_tokens = VocabParallelEmbedding( - config.vocab_size, - config.hidden_size, - quant_config=quant_config, - prefix=f"{prefix}.embed_tokens") - else: - self.embed_tokens = PPMissingLayer() - - self.start_layer, self.end_layer, self.layers = make_layers( - config.num_hidden_layers, - lambda prefix: TorchairDeepseekV2DecoderLayer( - config, - prefix, - model_config=model_config, - cache_config=cache_config, - quant_config=quant_config, - ), - prefix=f"{prefix}.layers") - - if get_pp_group().is_last_rank: - self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - else: - self.norm = PPMissingLayer() - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size)) - - def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.embed_tokens(input_ids) - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - kv_caches: Optional[List[torch.Tensor]] = None, - attn_metadata: Optional[AttentionMetadata] = None, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: - if get_pp_group().is_first_rank: - if inputs_embeds is not None: - hidden_states = inputs_embeds - else: - hidden_states = self.embed_input_ids(input_ids) - residual = None - else: - assert intermediate_tensors is not None - hidden_states = intermediate_tensors["hidden_states"] - residual = intermediate_tensors["residual"] - - replace_allreduce = hidden_states.shape[0] % self.tp_size == 0 - - for i in range(self.start_layer, self.end_layer): - layer = self.layers[i] - hidden_states, residual = layer( - positions, - hidden_states, - residual, - kv_caches[i - - self.start_layer] if kv_caches is not None else None, - attn_metadata, - replace_allreduce=replace_allreduce) - - if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) - - hidden_states, _ = 
self.norm(hidden_states, residual) - return hidden_states - - -class TorchairDeepseekV2ForCausalLM(DeepseekV2ForCausalLM): - # add `packed_modules_mapping` in `DeepseekV2ForCausalLM` to support weight merging - packed_modules_mapping = { - "gate_up_proj": ["gate_proj", "up_proj"], - "experts": - ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"] - } - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - nn.Module.__init__(self) - config = vllm_config.model_config.hf_config - quant_config = vllm_config.quant_config - self.config = config - self.num_dense_layers = self.config.first_k_dense_replace - self.num_moe_layers = self.config.num_hidden_layers - self.num_dense_layers - self.quant_config = quant_config - self.model = TorchairDeepseekV2Model(vllm_config=vllm_config, - prefix=maybe_prefix( - prefix, "model")) - if get_pp_group().is_last_rank: - self.lm_head = ParallelLMHead(config.vocab_size, - config.hidden_size, - quant_config=quant_config, - prefix=maybe_prefix( - prefix, "lm_head")) - else: - self.lm_head = PPMissingLayer() - self.logits_processor = LogitsProcessor(config.vocab_size) - self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) - - # NOTE: This `load_weights` is mainly copied from - # https://github.com/vllm-project/vllm/commit/07b8fae219b1fff51ef115c38c44b51395be5bb5 - # to fix CI, and it is different from the implementation in main - # TODO: support eplb style load_weights - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: - """""" - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("gate_up_proj", "gate_proj", 0), - ("gate_up_proj", "up_proj", 1), - ] - - # Params for weights, fp8 weight scales, fp8 activation scales - # (param_name, weight_name, expert_id, shard_id) - expert_params_mapping = TorchairAscendFusedMoE.make_expert_params_mapping( - ckpt_gate_proj_name="gate_proj", - ckpt_down_proj_name="down_proj", - ckpt_up_proj_name="up_proj", - num_experts=self.config.n_routed_experts) - - params_dict = dict(self.named_parameters()) - loaded_params: set[str] = set() - for name, loaded_weight in weights: - if "rotary_emb.inv_freq" in name: - continue - if "module" in name: - continue - - spec_layer = get_spec_layer_idx_from_weight_name(self.config, name) - if spec_layer is not None: - continue # skip spec decode layers for main model - - for (param_name, weight_name, shard_id) in stacked_params_mapping: - # Skip non-stacked layers and experts (experts handled below). - if weight_name not in name: - continue - # We have mlp.experts[0].gate_proj in the checkpoint. - # Since we handle the experts below in expert_params_mapping, - # we need to skip here BEFORE we update the name, otherwise - # name will be updated to mlp.experts[0].gate_up_proj, which - # will then be updated below in expert_params_mapping - # for mlp.experts[0].gate_gate_up_proj, which breaks load. - if (("mlp.experts." in name) and name not in params_dict): - continue - name = name.replace(weight_name, param_name) - # Skip loading extra bias for GPTQ models. 
- if name.endswith(".bias") and name not in params_dict: - continue - - if is_pp_missing_parameter(name, self): - continue - - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - break - else: - for mapping in expert_params_mapping: - param_name, weight_name, expert_id, shard_id = mapping - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - - if is_pp_missing_parameter(name, self): - continue - - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, - loaded_weight, - name, - shard_id=shard_id, - expert_id=expert_id, - return_success=False) - break - else: - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - - # Remapping the name of FP8 kv-scale. - name = maybe_remap_kv_scale_name(name, params_dict) - if name is None: - continue - - if is_pp_missing_parameter(name, self): - continue - - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) - loaded_params.add(name) - return loaded_params - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - kv_caches: Optional[List[torch.Tensor]] = None, - attn_metadata: Optional[AttentionMetadata] = None, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors, - inputs_embeds) - return hidden_states diff --git a/vllm_ascend/torchair/models/torchair_deepseek_v3.py b/vllm_ascend/torchair/models/torchair_deepseek_v3.py deleted file mode 100644 index aef8ae00..00000000 --- a/vllm_ascend/torchair/models/torchair_deepseek_v3.py +++ /dev/null @@ -1,28 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. -# Copyright 2023 The vLLM team. -# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved. -# -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from vllm_ascend.torchair.models.torchair_deepseek_v2 import \ - TorchairDeepseekV2ForCausalLM - - -class TorchairDeepseekV3ForCausalLM(TorchairDeepseekV2ForCausalLM): - pass diff --git a/vllm_ascend/torchair/models/torchair_pangu_moe.py b/vllm_ascend/torchair/models/torchair_pangu_moe.py deleted file mode 100644 index ed34c647..00000000 --- a/vllm_ascend/torchair/models/torchair_pangu_moe.py +++ /dev/null @@ -1,1116 +0,0 @@ -# -# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. 
-# -# This file is a part of the vllm-ascend project. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union - -import torch -import torch.distributed as dist -import torch.nn.functional as F -import torch_npu -from torch import nn -from torch.nn import Parameter -from transformers import PretrainedConfig -from vllm.attention.backends.abstract import AttentionMetadata -from vllm.attention.layer import Attention -from vllm.compilation.decorators import support_torch_compile -from vllm.config import CacheConfig, VllmConfig -from vllm.distributed import (divide, get_pp_group, - get_tensor_model_parallel_world_size, - tensor_model_parallel_all_reduce) -from vllm.distributed.parallel_state import (get_dp_group, get_ep_group, - get_tp_group, get_world_group) -from vllm.forward_context import get_forward_context -from vllm.logger import logger -from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.fused_moe import FusedMoE -from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (LinearBase, - MergedColumnParallelLinear, - QKVParallelLinear, - ReplicatedLinear, - RowParallelLinear) -from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.models.interfaces import SupportsPP -from vllm.model_executor.models.utils import ( - extract_layer_index, is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) -from vllm.model_executor.utils import set_weight_attrs -from vllm.sequence import IntermediateTensors -from vllm.v1.sample.sampler import Sampler - -from vllm_ascend.ascend_config import get_ascend_config -from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, AscendDeviceType, - get_ascend_device_type) - -_ROUTER_SCALE = None - - -def use_h2p(): - # only use H2P when dp_size > 1. - if get_dp_group().world_size > 1: - return True - return False - - -# This class is adapted from vllm.model_executor.layers.linear.MergedColumnParallelLinear. -# It is used to customize parallelism of certain linear(e.g., shared experts with all-rank tp). -class CustomMergedColumnParallelLinear(LinearBase): - - def __init__( - self, - input_size: int, - output_sizes: list[int], - bias: bool = True, - gather_output: bool = False, - skip_bias_add: bool = False, - params_dtype: Optional[torch.dtype] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - *, - return_bias: bool = True, - ): - # Divide the weight matrix along the last dimension. 
- output_size = sum(output_sizes) - self.output_sizes = output_sizes - self.tp_size = get_tp_group().world_size - self.input_size_per_partition = input_size - self.output_size_per_partition = divide(output_size, self.tp_size) - self.output_partition_sizes = [self.output_size_per_partition] - # If QKV or MergedColumn, use output size of each partition. - if hasattr(self, "output_sizes"): - self.output_partition_sizes = [ - divide(output_size, self.tp_size) - for output_size in self.output_sizes - ] - - super().__init__(input_size, - output_size, - skip_bias_add, - params_dtype, - quant_config, - prefix, - return_bias=return_bias) - - self.gather_output = gather_output - - if output_sizes is None: - output_sizes = [output_size] - - assert self.quant_method is not None - self.quant_method.create_weights( - layer=self, - input_size_per_partition=self.input_size_per_partition, - output_partition_sizes=self.output_partition_sizes, - input_size=self.input_size, - output_size=self.output_size, - params_dtype=self.params_dtype, - weight_loader=self.weight_loader) - if bias: - self.bias = Parameter( - torch.empty(self.output_size_per_partition, - dtype=params_dtype)) - set_weight_attrs(self.bias, { - "output_dim": 0, - "weight_loader": self.weight_loader, - }) - else: - self.register_parameter("bias", None) - - def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor, - loaded_shard_id: int): - param_data = param.data - output_dim = getattr(param, "output_dim", None) - - assert loaded_shard_id < len(self.output_sizes) - - tp_rank = get_tp_group().rank_in_group - tp_size = get_tp_group().world_size - if output_dim is not None: - shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size - shard_size = self.output_sizes[loaded_shard_id] // tp_size - - is_sharded_weight = getattr(param, "is_sharded_weight", False) - param_data = param_data.narrow(output_dim, shard_offset, - shard_size) - start_idx = tp_rank * shard_size - if not is_sharded_weight: - loaded_weight = loaded_weight.narrow(output_dim, start_idx, - shard_size) - else: - ignore_warning = getattr(param, "ignore_warning", False) - if not ignore_warning: - logger.warning( - "Loading a weight without `output_dim` attribute in " - "MergedColumnParallelLinear, assume the weight is " - "the same for all partitions.") - - assert param_data.shape == loaded_weight.shape - param_data.copy_(loaded_weight) - - def forward( - self, input_ - ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]: - bias = self.bias if not self.skip_bias_add else None - - # Matrix multiply. - assert self.quant_method is not None - output_parallel = self.quant_method.apply(self, input_, bias) - output = output_parallel - output_bias = self.bias if self.skip_bias_add else None - if not self.return_bias: - return output - return output, output_bias - - -# This class is adapted from vllm.model_executor.layers.linear.RowParallelLinear. -# It is used to customize parallelism of certain linear(e.g., shared experts with all-rank tp) -# and detach communication to enable customized communication algorithms(e.g., H2P). 
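To make the "detached communication" idea concrete before the class definition that follows: a stock row-parallel linear all-reduces its partial output inside forward, whereas the detached variant returns only the local partial product and lets the caller pick the communication schema (a reduce-scatter in the H2P path instead of an all-reduce). A minimal sketch under those assumptions, with plain torch.distributed and hypothetical names:

```python
import torch
import torch.distributed as dist
import torch.nn as nn


class DetachedRowParallelLinear(nn.Module):
    """Sketch of a row-parallel linear with communication detached.

    Hypothetical, simplified: no bias, no quantization, no weight loading.
    Each rank owns a slice of the input dimension and the full output
    dimension; forward() returns only the local partial product.
    """

    def __init__(self, in_features: int, out_features: int, world_size: int):
        super().__init__()
        assert in_features % world_size == 0
        self.weight = nn.Parameter(
            torch.empty(out_features, in_features // world_size))
        nn.init.normal_(self.weight, std=0.02)

    def forward(self, x_parallel: torch.Tensor) -> torch.Tensor:
        # x_parallel: (..., in_features // world_size), already sharded.
        return x_parallel @ self.weight.t()  # partial result, no collective


def combine_partials(partial: torch.Tensor) -> torch.Tensor:
    # The caller chooses the communication schema: a plain all-reduce here,
    # a reduce-scatter in the H2P path. Requires an initialized process group.
    dist.all_reduce(partial)
    return partial
```

The deleted CustomRowParallelLinear below plays the same role, with quantization-aware weight creation and vLLM group plumbing added on top.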
-class CustomRowParallelLinear(LinearBase): - - def __init__( - self, - input_size: int, - output_size: int, - bias: bool = True, - input_is_parallel: bool = True, - skip_bias_add: bool = False, - params_dtype: Optional[torch.dtype] = None, - reduce_results: bool = True, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - *, - return_bias: bool = True, - group=None, - ): - # Divide the weight matrix along the first dimension. - self.group = group if group is not None else get_tp_group() - self.tp_rank = self.group.rank_in_group - self.tp_size = self.group.world_size - self.input_size_per_partition = divide(input_size, self.tp_size) - self.output_size_per_partition = output_size - self.output_partition_sizes = [output_size] - - super().__init__(input_size, - output_size, - skip_bias_add, - params_dtype, - quant_config, - prefix, - return_bias=return_bias) - - self.input_is_parallel = input_is_parallel - self.reduce_results = reduce_results - - assert self.quant_method is not None - self.quant_method.create_weights( - layer=self, - input_size_per_partition=self.input_size_per_partition, - output_partition_sizes=self.output_partition_sizes, - input_size=self.input_size, - output_size=self.output_size, - params_dtype=self.params_dtype, - weight_loader=self.weight_loader) - if not reduce_results and (bias and not skip_bias_add): - raise ValueError("When not reduce the results, adding bias to the " - "results can lead to incorrect results") - - if bias: - self.bias = Parameter( - torch.empty(self.output_size, dtype=params_dtype)) - set_weight_attrs(self.bias, { - "output_dim": 0, - "weight_loader": self.weight_loader, - }) - else: - self.register_parameter("bias", None) - - def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): - tp_rank = self.group.rank_in_group - input_dim = getattr(param, "input_dim", None) - is_sharded_weight = getattr(param, "is_sharded_weight", False) - is_sharded_weight = is_sharded_weight - - param_data = param.data - if input_dim is not None and not is_sharded_weight: - shard_size = param_data.shape[input_dim] - start_idx = tp_rank * shard_size - loaded_weight = loaded_weight.narrow(input_dim, start_idx, - shard_size) - - # Special case for loading scales off disk, which often do not - # have a shape (such as in the case of AutoFP8). - if len(loaded_weight.shape) == 0: - loaded_weight = loaded_weight.reshape(1) - - assert param_data.shape == loaded_weight.shape - param_data.copy_(loaded_weight) - - def forward( - self, input_ - ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]: - input_parallel = input_ - - # Matrix multiply. 
- assert self.quant_method is not None - # Only fuse bias add into GEMM for rank 0 (this ensures that - # bias will not get added more than once in TP>1 case) - bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias - output = self.quant_method.apply(self, input_parallel, bias=bias_) - - output_bias = self.bias if self.skip_bias_add else None - - if not self.return_bias: - return output - return output, output_bias - - -class PanguProMoEMLP(nn.Module): - - def __init__( - self, - hidden_size: int, - intermediate_size: int, - hidden_act: str, - quant_config: Optional[QuantizationConfig] = None, - reduce_results: bool = True, - prefix: str = "", - ) -> None: - super().__init__() - if not use_h2p(): - self.gate_up_proj = MergedColumnParallelLinear( - hidden_size, - [intermediate_size] * 2, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.gate_up_proj", - ) - self.down_proj = RowParallelLinear( - intermediate_size, - hidden_size, - bias=False, - quant_config=quant_config, - reduce_results=reduce_results, - prefix=f"{prefix}.down_proj", - ) - else: - self.gate_up_proj = CustomMergedColumnParallelLinear( - hidden_size, - [intermediate_size] * 2, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.gate_up_proj", - ) - self.down_proj = CustomRowParallelLinear( - intermediate_size, - hidden_size, - bias=False, - quant_config=quant_config, - reduce_results=reduce_results, - prefix=f"{prefix}.down_proj", - ) - - if hidden_act != "silu": - raise ValueError(f"Unsupported activation: {hidden_act}. " - "Only silu is supported for now.") - self.act_fn = SiluAndMul() - - def forward(self, x): - gate_up, _ = self.gate_up_proj(x) - x = self.act_fn(gate_up) - x, _ = self.down_proj(x) - return x - - -def topk_wrapper(num_voted_experts): - - def pangu_group8_topk( - hidden_states: torch.Tensor, - gating_output: torch.Tensor, - topk: int, - renormalize: bool = False, - num_expert_group: int = 0, - topk_group: int = 0, - global_num_experts: int = 0, - ): - scores = F.softmax(gating_output, dim=1) - num_tokens = scores.shape[0] - router_scale = _ROUTER_SCALE.squeeze( # type: ignore - ) - # TODO: support disable expert parallel - ep_size = get_ep_group().world_size - local_num_experts = global_num_experts // ep_size - local_num_group = topk // ep_size - experts_per_group = global_num_experts // topk - local_group_start = get_ep_group().rank_in_group * local_num_experts - local_group_end = (get_ep_group().rank_in_group + - 1) * local_num_experts - scores = F.softmax(gating_output, dim=1) - scores = scores[..., local_group_start:local_group_end] - - router_weights = router_scale[local_group_start:local_group_end] - - if num_voted_experts == 8: - # use original topk - topk_weights, topk_ids = torch.max(scores.view( - scores.shape[0], local_num_group, -1), - dim=-1) - bias = torch.arange(0, - local_num_experts, - experts_per_group, - device=scores.device, - dtype=torch.int32).unsqueeze(0) - topk_ids = topk_ids.to(torch.int32) + bias - - else: - group_expert_indices = torch.arange(experts_per_group, - dtype=torch.int32, - device=scores.device).view( - 1, 1, -1) - group_expert_offset = (torch.arange( - local_num_group, dtype=torch.int32, device=scores.device) * - experts_per_group).unsqueeze(0) - expert_index_range = torch.arange(experts_per_group, - dtype=torch.int32, - device=scores.device) - - scores_grouped = scores.view(num_tokens, local_num_group, - experts_per_group) - best_expert_idx = torch.argmax(scores_grouped, - dim=2) # (num_tokens, num_groups) - vote_mask = 
(best_expert_idx.unsqueeze(-1).to( - torch.int32) == group_expert_indices) - - expert_vote_freq = vote_mask.sum(dim=0) - - sorted_indices = torch.argsort(expert_vote_freq, - dim=1, - descending=True).to(torch.int32) - topk_experts = sorted_indices[:, :num_voted_experts] - keep_mask = (( - topk_experts.unsqueeze(-1) == expert_index_range).any( - dim=1)).unsqueeze(0) - - masked_scores = torch.where(keep_mask, scores_grouped, 0) - - topk_weights, best_pos_in_group = masked_scores.max(dim=2) - best_pos_in_group = best_pos_in_group.to(torch.int32) - topk_ids = (best_pos_in_group + group_expert_offset).to( - torch.int32) - - flatten_topk_ids = topk_ids.view(-1) - router_weights = router_weights.index_select(0, flatten_topk_ids).view( - topk_ids.shape) - topk_weights *= router_weights - - return topk_weights, topk_ids - - return pangu_group8_topk - - -class PanguProMoESparseMoeBlock(nn.Module): - - def __init__( - self, - config: PretrainedConfig, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ): - super().__init__() - self.tp_size = get_tensor_model_parallel_world_size() - self.num_experts = config.num_experts - - if self.tp_size > config.num_experts: - raise ValueError( - f"Tensor parallel size {self.tp_size} is greater than " - f"the number of experts {config.num_experts}.") - - self.num_experts_per_tok = config.num_experts_per_tok - self.router_scale = torch.nn.Parameter( - torch.ones((1, self.num_experts))) - - # on 300I Duo platform, we find that num_voted_experts set to 5 achieves - # good performance without sacrifice too much accuracy. for other platform, - # this is set to 8 to use original pangu grouped topk. - num_voted_experts = 5 if get_ascend_device_type( - ) == AscendDeviceType._310P else 8 - - self.experts = FusedMoE( - num_experts=config.num_experts, - top_k=config.num_experts_per_tok, - hidden_size=config.hidden_size, - intermediate_size=config.moe_intermediate_size, - reduce_results=False, - quant_config=quant_config, - custom_routing_function=topk_wrapper(num_voted_experts), - prefix=f"{prefix}.experts", - ) - self.use_ep = self.experts.use_ep - - self.gate = ReplicatedLinear( - config.hidden_size, - config.num_experts, - bias=False, - quant_config=None, - prefix=f"{prefix}.gate", - ) - - if config.shared_expert_intermediate_size > 0: - self.shared_expert = PanguProMoEMLP( - hidden_size=config.hidden_size, - intermediate_size=config.shared_expert_intermediate_size, - hidden_act=config.hidden_act, - quant_config=quant_config, - reduce_results=False, - prefix=f"{prefix}.shared_expert", - ) - else: - self.shared_expert = None # type: ignore - - def forward( - self, - hidden_states: torch.Tensor, - attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor: - # NOTE: hidden_states can have either 1D or 2D shape. - num_tokens, hidden_dim = hidden_states.shape - hidden_states = hidden_states.view(-1, hidden_dim) - shared_output = None - if self.shared_expert is not None: - shared_output = self.shared_expert(hidden_states) - # router_logits: (num_tokens, n_experts) - router_logits, _ = self.gate(hidden_states) - global _ROUTER_SCALE - _ROUTER_SCALE = self.router_scale - - # TODO(angazenn): Does not support MC2 currently - get_forward_context().moe_comm_method_name = "allgathercommimpl" - - if not use_h2p(): - final_hidden_states = self.experts.forward_impl( - hidden_states=hidden_states, router_logits=router_logits) - else: - # TODO: when using h2p, we have to skip communication in vLLM - # native FusedMoE. 
here we need to design a better FusedMoE - # (maybe using AscendFusedMoE) to enable these different - # communication schema. - final_hidden_states = self.experts.quant_method.apply( - layer=self.experts, - x=hidden_states, - router_logits=router_logits, - top_k=self.experts.top_k, - renormalize=False, - use_grouped_topk=False, - global_num_experts=self.experts.global_num_experts, - expert_map=self.experts.expert_map, - custom_routing_function=self.experts.custom_routing_function, - apply_router_weight_on_input=self.experts. - apply_router_weight_on_input) - - if shared_output is not None: - final_hidden_states = final_hidden_states + shared_output - if not use_h2p(): - final_hidden_states = tensor_model_parallel_all_reduce( - final_hidden_states) - - return final_hidden_states.view(num_tokens, hidden_dim) - - -class PanguProMoEAttention(nn.Module): - - def __init__( - self, - hidden_size: int, - num_heads: int, - num_kv_heads: int, - rope_parameters: Dict[str, Any], - max_position_embeddings: int = 8192, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ) -> None: - super().__init__() - self.hidden_size = hidden_size - tp_size = get_tensor_model_parallel_world_size() - self.total_num_heads = num_heads - assert self.total_num_heads % tp_size == 0 - self.num_heads = self.total_num_heads // tp_size - self.total_num_kv_heads = num_kv_heads - if self.total_num_kv_heads >= tp_size: - # Number of KV heads is greater than TP size, so we partition - # the KV heads across multiple tensor parallel GPUs. - assert self.total_num_kv_heads % tp_size == 0 - else: - # Number of KV heads is less than TP size, so we replicate - # the KV heads across multiple tensor parallel GPUs. - assert tp_size % self.total_num_kv_heads == 0 - self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) - self.head_dim = hidden_size // self.total_num_heads - self.q_size = self.num_heads * self.head_dim - self.kv_size = self.num_kv_heads * self.head_dim - self.scaling = self.head_dim**-0.5 - self.max_position_embeddings = max_position_embeddings - - self.qkv_proj = QKVParallelLinear( - hidden_size, - self.head_dim, - self.total_num_heads, - self.total_num_kv_heads, - bias=True, - quant_config=quant_config, - prefix=f"{prefix}.qkv_proj", - ) - - if use_h2p(): - self.o_proj = CustomRowParallelLinear(self.total_num_heads * - self.head_dim, - hidden_size, - bias=True, - quant_config=quant_config, - prefix=f"{prefix}.o_proj", - group=get_tp_group()) - else: - self.o_proj = RowParallelLinear( - self.total_num_heads * self.head_dim, - hidden_size, - bias=True, - quant_config=quant_config, - prefix=f"{prefix}.o_proj", - ) - - self.rotary_emb = get_rope( - self.head_dim, - rotary_dim=self.head_dim, - max_position=max_position_embeddings, - rope_parameters=rope_parameters, - ) - self.attn = Attention( - self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn", - ) - - ascend_config = get_ascend_config() - self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled - - def forward( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - kv_cache: Optional[torch.Tensor] = None, - attn_metadata: Optional[AttentionMetadata] = None, - ) -> torch.Tensor: - qkv, _ = self.qkv_proj(hidden_states) - q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - q, k = self.rotary_emb(positions, q, k) - if 
self.torchair_graph_enabled: - forward_kwargs = {} - output_shape = q.shape - attn_output = torch.empty(output_shape, - dtype=q.dtype, - device=q.device) - forward_kwargs['output'] = attn_output - attn_output = self.attn.impl.forward(self.attn, q, k, v, kv_cache, - attn_metadata, - **forward_kwargs) - else: - attn_output = self.attn(q, k, v) - - output, _ = self.o_proj(attn_output) - return output - - -class PanguProMoEDecoderLayer(nn.Module): - - def __init__( - self, - config: PretrainedConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ) -> None: - super().__init__() - self.hidden_size = config.hidden_size - max_position_embeddings = getattr(config, "max_position_embeddings", - 8192) - - self.self_attn = PanguProMoEAttention( - hidden_size=self.hidden_size, - num_heads=config.num_attention_heads, - num_kv_heads=config.num_key_value_heads, - rope_parameters=config.rope_parameters, - max_position_embeddings=max_position_embeddings, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.self_attn", - ) - - # `mlp_only_layers` in the config. - layer_idx = extract_layer_index(prefix) - mlp_only_layers = ([] if not hasattr(config, "mlp_only_layers") else - config.mlp_only_layers) - if (layer_idx not in mlp_only_layers) and (config.num_experts > 0): - self.mlp = PanguProMoESparseMoeBlock( - config=config, - quant_config=quant_config, - prefix=f"{prefix}.mlp", - ) - else: - self.mlp = PanguProMoEMLP( - hidden_size=config.hidden_size, - intermediate_size=config.intermediate_size, - hidden_act=config.hidden_act, - quant_config=quant_config, - prefix=f"{prefix}.mlp", - ) - self.input_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_attention_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - - def forward( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - residual: Optional[torch.Tensor], - kv_cache: Optional[torch.Tensor] = None, - attn_metadata: Optional[AttentionMetadata] = None, - h2p_unpad_idx: Optional[torch.Tensor] = None, - h2p_pad_idx: Optional[torch.Tensor] = None, - is_start_layer: Optional[bool] = False, - ) -> torch.Tensor: - need_h2p_pad = h2p_unpad_idx is not None and h2p_pad_idx is not None \ - and h2p_unpad_idx.shape[0] < h2p_pad_idx.shape[0] - tp_size = get_tp_group().world_size - - # Self Attention - if residual is None: - residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) - else: - hidden_states, residual = self.input_layernorm( - hidden_states, residual) - if use_h2p(): - if is_start_layer: - if need_h2p_pad: - residual = residual.index_select(dim=0, index=h2p_pad_idx) - residual = torch.tensor_split( - residual, tp_size)[get_tp_group().rank_in_group] - else: - if tp_size > 1: - hidden_states = get_tp_group().all_gather(hidden_states, 0) - if need_h2p_pad: - hidden_states = hidden_states.index_select( - dim=0, index=h2p_unpad_idx) - - hidden_states = self.self_attn( - positions=positions, - hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, - ) - - if use_h2p(): - if need_h2p_pad: - hidden_states = hidden_states.index_select(dim=0, - index=h2p_pad_idx) - if tp_size > 1: - hidden_states = dist._functional_collectives.reduce_scatter_tensor( - hidden_states, - "sum", - scatter_dim=0, - group=get_tp_group().device_group) - - # Fully Connected - hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual) - - if use_h2p(): - 
all_rank_group = get_world_group().device_group - output_size = (hidden_states.shape[0] * - get_world_group().world_size, - hidden_states.shape[1]) - # Allocate output tensor. - output_tensor = torch.empty(output_size, - dtype=hidden_states.dtype, - device=hidden_states.device) - # All-gather. - dist.all_gather_into_tensor(output_tensor, - hidden_states, - group=all_rank_group) - hidden_states = output_tensor - - hidden_states = self.mlp(hidden_states, attn_metadata=attn_metadata) - - if use_h2p(): - hidden_states = dist._functional_collectives.reduce_scatter_tensor( - hidden_states, - "sum", - scatter_dim=0, - group=get_world_group().device_group) - - return hidden_states, residual - - -@support_torch_compile -class PanguProMoEModel(nn.Module): - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__() - - config = vllm_config.model_config.hf_config - cache_config = vllm_config.cache_config - quant_config = vllm_config.quant_config - - self.padding_idx = config.pad_token_id - self.vocab_size = config.vocab_size - - self.embed_tokens = VocabParallelEmbedding( - config.vocab_size, - config.hidden_size, - quant_config=quant_config, - prefix=f"{prefix}.embed_tokens") - - self.start_layer, self.end_layer, self.layers = make_layers( - config.num_hidden_layers, - lambda prefix: PanguProMoEDecoderLayer(config=config, - cache_config=cache_config, - quant_config=quant_config, - prefix=prefix), - prefix=f"{prefix}.layers", - ) - self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size)) - - def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.embed_tokens(input_ids) - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - kv_caches: Optional[List[torch.Tensor]] = None, - attn_metadata: Optional[AttentionMetadata] = None, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: - if get_pp_group().is_first_rank: - if inputs_embeds is not None: - hidden_states = inputs_embeds - else: - hidden_states = self.embed_input_ids(input_ids) - residual = None - else: - assert intermediate_tensors is not None - hidden_states = intermediate_tensors["hidden_states"] - residual = intermediate_tensors["residual"] - - if use_h2p(): - # calculate necessary padding/unpadding idx before model forward. - - # the attn_metadata will be passed directly when use torchair. - # if attn_meatadata is not passed, we try to get it from forward_context. - if attn_metadata is None: - attn_metadata = get_forward_context().attn_metadata - - max_tokens_across_dp = get_forward_context().max_tokens_across_dp - - tp_size = get_tp_group().world_size - # reduce scatter will split the input tensor into equal sizes and then scatter them on all ranks. - # we need pad it before if the shape can't be divided by group size. - # for h2p, we need pad it so that it can be divided by tp_size. 
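The deleted code directly below computes this padding from max_tokens_across_dp. As a standalone sketch of the same arithmetic (hypothetical helper, plain torch): the target length is the DP-wide token maximum rounded up to a multiple of tp_size, and two index tensors pad the local rows out to that length (reusing row 0 as filler) and later strip the filler again.

```python
import torch


def h2p_pad_indices(num_tokens: int, max_tokens_across_dp: int, tp_size: int,
                    device: str = "cpu"):
    """Hypothetical helper mirroring the padding arithmetic described above."""
    # Round the DP-wide maximum up to a multiple of tp_size.
    padded_len = -(-max_tokens_across_dp // tp_size) * tp_size
    pad_len = padded_len - num_tokens
    # Indices of the real rows, and the same indices extended with zeros so
    # that gathering with pad_idx repeats row 0 as filler up to padded_len.
    unpad_idx = torch.arange(num_tokens, dtype=torch.int32, device=device)
    pad_idx = torch.cat([
        unpad_idx,
        torch.zeros(pad_len, dtype=torch.int32, device=device),
    ])
    return unpad_idx, pad_idx


# Example: 10 local tokens, DP-wide maximum 10, tp_size 4 -> padded to 12 rows.
unpad_idx, pad_idx = h2p_pad_indices(10, 10, 4)
assert unpad_idx.shape[0] == 10 and pad_idx.shape[0] == 12
```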
- h2p_padded_len = ( - tp_size - (max_tokens_across_dp % tp_size) - ) % tp_size + max_tokens_across_dp - hidden_states.shape[0] - h2p_unpad_idx = torch.arange(hidden_states.shape[0], - device=hidden_states.device, - dtype=torch.int32) - h2p_pad_idx = torch.cat([ - h2p_unpad_idx, - torch.zeros(h2p_padded_len, - dtype=torch.int32, - device=hidden_states.device) - ]) - else: - h2p_unpad_idx = None - h2p_pad_idx = None - - for i in range(self.start_layer, self.end_layer): - layer = self.layers[i] - hidden_states, residual = layer( - positions, hidden_states, residual, - kv_caches[i - - self.start_layer] if kv_caches is not None else None, - attn_metadata, h2p_unpad_idx, h2p_pad_idx, - i == self.start_layer) - if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) - hidden_states, _ = self.norm(hidden_states, residual) - if use_h2p(): - if get_tp_group().world_size > 1: - hidden_states = get_tp_group().all_gather(hidden_states, 0) - if h2p_unpad_idx.shape[0] < h2p_pad_idx.shape[0]: - hidden_states = hidden_states.index_select(dim=0, - index=h2p_unpad_idx) - return hidden_states - - -class PanguProMoEForCausalLM(nn.Module, SupportsPP): - - fall_back_to_pt_during_load = False - - packed_modules_mapping = { - "qkv_proj": ["q_proj", "k_proj", "v_proj"], - "gate_up_proj": ["gate_proj", "up_proj"], - "experts": - ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"] - } - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__() - config = vllm_config.model_config.hf_config - quant_config = vllm_config.quant_config - self.config = config - self.quant_config = quant_config - self.model = PanguProMoEModel(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) - self.lm_head = ParallelLMHead( - config.vocab_size, - config.hidden_size, - quant_config=quant_config, - prefix=f"{prefix}.lm_head", - ) - if self.config.tie_word_embeddings: - self.lm_head.weight = self.model.embed_tokens.weight - self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = Sampler() - self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) - - def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.model.embed_input_ids(input_ids) - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - kv_caches: Optional[List[torch.Tensor]] = None, - attn_metadata: Optional[AttentionMetadata] = None, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors, - inputs_embeds) - return hidden_states - - def compute_logits( - self, - hidden_states: torch.Tensor, - sampling_metadata=None, # type: ignore - ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) - return logits - - def sample( - self, - logits: Optional[torch.Tensor], - sampling_metadata, # type: ignore - ): - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): - tp_size = get_tp_group().world_size - tp_rank = get_tp_group().rank_in_group - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - 
("gate_up_proj", "gate_proj", 0), - ("gate_up_proj", "up_proj", 1), - ] - - # Params for weights, fp8 weight scales, fp8 activation scales - # (param_name, weight_name, expert_id, shard_id) - expert_params_mapping = FusedMoE.make_expert_params_mapping( - ckpt_gate_proj_name="gate_proj", - ckpt_down_proj_name="down_proj", - ckpt_up_proj_name="up_proj", - num_experts=self.config.num_experts) - - params_dict = dict(self.named_parameters()) # from model - loaded_params: Set[str] = set() - for name, loaded_weight in weights: - # ======================================================= - # BF: add this to load with less layers - if 'layers' in name: - layer_idx = int(name.split('layers.')[-1].split('.')[0]) - if layer_idx >= self.model.end_layer: - continue - - if "rotary_emb.inv_freq" in name: - continue - - if "module" in name: - continue - - if name.endswith('kv_cache_offset'): - continue - - if name.endswith("k_proj.kv_cache_scale"): - remapped_kv_scale_name = name.replace( - "k_proj.kv_cache_scale", "attn.key_antiquant_scale") - if remapped_kv_scale_name not in params_dict: - logger.warning_once( - "Found kv scale in the checkpoint " - f"(e.g. {name}), but not found the expected " - f"name in the model " - f"(e.g. {remapped_kv_scale_name}). " - "kv-scale is not loaded.") - continue - else: - name = remapped_kv_scale_name - param = params_dict[name] - loaded_weight = torch.tensor_split(loaded_weight, - tp_size, - dim=0)[tp_rank] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) - - if name.endswith("v_proj.kv_cache_scale"): - remapped_kv_scale_name = name.replace( - "v_proj.kv_cache_scale", "attn.value_antiquant_scale") - if remapped_kv_scale_name not in params_dict: - logger.warning_once( - "Found kv scale in the checkpoint " - f"(e.g. {name}), but not found the expected " - f"name in the model " - f"(e.g. {remapped_kv_scale_name}). " - "kv-scale is not loaded.") - continue - else: - name = remapped_kv_scale_name - param = params_dict[name] - loaded_weight = torch.tensor_split(loaded_weight, - tp_size, - dim=0)[tp_rank] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) - - for (param_name, weight_name, shard_id) in stacked_params_mapping: - # Skip non-stacked layers and experts (experts handled below). - if weight_name not in name: - continue - # We have mlp.experts[0].gate_proj in the checkpoint. - # Since we handle the experts below in expert_params_mapping, - # we need to skip here BEFORE we update the name, otherwise - # name will be updated to mlp.experts[0].gate_up_proj, which - # will then be updated below in expert_params_mapping - # for mlp.experts[0].gate_gate_up_proj, which breaks load. - if "mlp.experts" in name: - continue - name = name.replace(weight_name, param_name) - # Skip loading extra bias for GPTQ models. - if ((name.endswith(".bias") or name.endswith("_bias")) - and name not in params_dict): - continue - - # Skip layers on other devices. - if is_pp_missing_parameter(name, self): - continue - if name not in params_dict: - continue - - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - break - else: - for mapping in expert_params_mapping: - param_name, weight_name, expert_id, shard_id = mapping - if weight_name not in name: - continue - # breakpoint() - name = name.replace(weight_name, param_name) - # Skip layers on other devices. 
- if is_pp_missing_parameter(name, self): - continue - # Skip loading extra bias for GPTQ models. - if ((name.endswith(".bias") or name.endswith("_bias")) - and name not in params_dict): - continue - param = params_dict[name] - weight_loader = param.weight_loader - # breakpoint() - weight_loader(param, - loaded_weight, - name, - shard_id=shard_id, - expert_id=expert_id) - break - else: - # Skip loading extra bias for GPTQ models. - if ((name.endswith(".bias") or name.endswith("_bias")) - and name not in params_dict): - continue - # Skip layers on other devices. - if is_pp_missing_parameter(name, self): - continue - # Remapping the name of FP8 kv-scale. - if name.endswith("kv_scale"): - remapped_kv_scale_name = name.replace( - ".kv_scale", ".attn.kv_scale") - if remapped_kv_scale_name not in params_dict: - logger.warning_once( - "Found kv scale in the checkpoint " - f"(e.g. {name}), but not found the expected " - f"name in the model " - f"(e.g. {remapped_kv_scale_name}). " - "kv-scale is not loaded.") - continue - else: - name = remapped_kv_scale_name - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) - loaded_params.add(name) - if get_ascend_device_type( - ) == AscendDeviceType._310P and "head" in name: - # on 300I Duo platform, ACL_FORMAT_FRACTAL_NZ is much more preferred than - # ACL_FORMAT_FRACTAL_ND by matmul operation. Since lmhead is also implemented - # by linear, we manually cast the format here. - param.data = torch_npu.npu_format_cast(param.data, - ACL_FORMAT_FRACTAL_NZ) - return loaded_params diff --git a/vllm_ascend/torchair/ops/__init__.py b/vllm_ascend/torchair/ops/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/vllm_ascend/torchair/ops/sequence_parallel.py b/vllm_ascend/torchair/ops/sequence_parallel.py deleted file mode 100644 index bfd327b4..00000000 --- a/vllm_ascend/torchair/ops/sequence_parallel.py +++ /dev/null @@ -1,120 +0,0 @@ -import torch -from torch.nn import functional as F -from vllm.distributed import (get_tensor_model_parallel_world_size, - get_tp_group, tensor_model_parallel_all_gather, - tensor_model_parallel_reduce_scatter) -from vllm.forward_context import get_forward_context - -from vllm_ascend.platform import NPUPlatform - - -class MetadataForPadding: - - def __init__(self, - padding_flag=False, - lengths_sum_padding=0, - lengths_sum_unpadding=0, - pad_size=0, - not_dummy_and_is_prefill=False): - self.padding_flag = padding_flag - self.not_dummy_and_is_prefill = not_dummy_and_is_prefill - - self.lengths_sum_padding = lengths_sum_padding - self.lengths_sum_unpadding = lengths_sum_unpadding - self.pad_size = pad_size - - self.tp_size = get_tp_group().world_size - self.tp_rank_in_group = get_tp_group().rank_in_group - - assert self.lengths_sum_padding % self.tp_size == 0 - self.slice_size = self.lengths_sum_padding // self.tp_size - - self.mc2_mask = torch.zeros( - self.lengths_sum_padding, - dtype=torch.bool, - device=NPUPlatform.device_type, - ) - self.mc2_mask[:lengths_sum_unpadding] = True - - def padding_aligned_reduce_scatter(self, - data: torch.Tensor) -> torch.Tensor: - if self.padding_flag: - pad_size = self.pad_size - padded_data = F.pad(data, (0, 0, 0, pad_size)) - else: - padded_data = data - padded_data_reduce_scatter = tensor_model_parallel_reduce_scatter( - padded_data, 0) - - return padded_data_reduce_scatter - - def allgather_unpadding_aligned(self, - padded_data: torch.Tensor) -> torch.Tensor: - padded_data_allgather = 
tensor_model_parallel_all_gather( - padded_data, 0) - if self.padding_flag: - lengths_sum_unpadding = self.lengths_sum_unpadding - unpadding_data = padded_data_allgather[:lengths_sum_unpadding] - else: - unpadding_data = padded_data_allgather - return unpadding_data - - def padding_slice(self, data: torch.Tensor) -> torch.Tensor: - - padded_data = F.pad(data, (0, 0, 0, self.pad_size)) - start = self.tp_rank_in_group * self.slice_size - end = start + self.slice_size - slice_data = padded_data[start:end] - - return slice_data - - def padding_aligned_scatter(self, data: torch.Tensor) -> torch.Tensor: - if self.padding_flag: - pad_size = self.pad_size - padded_data = F.pad(data, (0, 0, 0, pad_size)) - else: - padded_data = data - # padded_data = data - padded_data = torch.tensor_split(padded_data, self.tp_size, dim=0) - - padded_data_reduce_scatter = padded_data[self.tp_rank_in_group] - - return padded_data_reduce_scatter - - -def init_metadata_for_sp(input_ids, enable_sequence_parallelism): - if not enable_sequence_parallelism: - return MetadataForPadding(padding_flag=False, - not_dummy_and_is_prefill=False) - - is_perifll = 0 - attn_metadata = get_forward_context().attn_metadata - tp_size = get_tensor_model_parallel_world_size() - if attn_metadata is not None: - if hasattr(attn_metadata, - 'is_only_prefill') and attn_metadata.is_only_prefill: - is_perifll = 1 - if hasattr(attn_metadata, - 'num_prefills') and attn_metadata.num_prefills > 0: - is_perifll = 1 - - if is_perifll: - lengths_sum_unpadding = input_ids.shape[0] - lengths_sum_padding = ( - (lengths_sum_unpadding + tp_size - 1) // tp_size) * tp_size - if lengths_sum_unpadding == lengths_sum_padding: - padding_flag = False - else: - padding_flag = True - pad_size = lengths_sum_padding - lengths_sum_unpadding - _metadata_for_padding = MetadataForPadding( - lengths_sum_unpadding=lengths_sum_unpadding, - lengths_sum_padding=lengths_sum_padding, - padding_flag=padding_flag, - pad_size=pad_size, - not_dummy_and_is_prefill=True) - - return _metadata_for_padding - - return MetadataForPadding(padding_flag=False, - not_dummy_and_is_prefill=False) diff --git a/vllm_ascend/torchair/ops/shared_weight_layer.py b/vllm_ascend/torchair/ops/shared_weight_layer.py deleted file mode 100644 index 6ab29af2..00000000 --- a/vllm_ascend/torchair/ops/shared_weight_layer.py +++ /dev/null @@ -1,245 +0,0 @@ -from dataclasses import dataclass -from typing import Callable, Optional - -import torch -import torch.distributed as dist -from vllm.distributed.parallel_state import GroupCoordinator -from vllm.model_executor.layers.linear import LinearBase - - -def dispose_tensor(x: torch.Tensor): - x.set_(torch.empty([], device=x.device, dtype=x.dtype)) - - -@dataclass -class LayerMetadata: - """Metadata for a layer. - """ - layer: Optional[LinearBase] # The layer object. - post_method: Callable[[ - torch.nn.Module - ], None] # The `process_weights_after_loading` method from the quant method. - weight: torch.Tensor # The weight tensor. - window_idx: int # The index of the window. - - -@dataclass -class SharedWindowMetadata: - """Metadata for a shared window. - """ - weight: torch.Tensor # The weight tensor to be shared by layers. - data_layer_idx: int # The index of the layer this window's weight is equal to. - work: Optional[torch.distributed.Work] # The asynchronous broadcast work. - - -@dataclass -class SeriesMetadata: - """Metadata for a weight shared series. 
- """ - group: GroupCoordinator - start_layer: int - end_layer: int - num_layers: int - prefetch_step: int - dummy_weight: torch.Tensor # Dummy weight to replace the loaded weight matrix. All the layers in the series share the same dummy weight tensor. - layers: list[LayerMetadata] - shared_windows: list[ - SharedWindowMetadata] # Shared windows for prefetching. The window size is (`prefetch_step` + 1), as only the weights for the next (`prefetch_step` + 1) layers need to be stored. - window_offset: int # The index of the window for the next coming layer. - - def is_source(self, layer_idx) -> bool: - return layer_idx % self.group.world_size == self.group.rank_in_group - - def post_process_after_loading(self): - # This method only needs to be called once per series. - if self.shared_windows: - return - for layer_idx in range(self.start_layer, self.end_layer): - layer = self.layers[layer_idx - self.start_layer] - is_source = self.is_source(layer_idx) - # If the weight uses dummy weight, make a copy temporary such that the post method call won't affect other layers which also uses dummy weight. - if not is_source: - layer.weight.set_(torch.empty_like(self.dummy_weight)) - # Broadcast to get the true weight. - dist.broadcast(layer.weight, - src=self.group.ranks[layer_idx % - self.group.world_size], - group=self.group.device_group) - assert layer.layer is not None - # Call `process_weights_after_loading` from the quant method. - layer.post_method(layer.layer) - step = layer_idx - self.start_layer - if step < self.prefetch_step: - # Build the windows for the first `prefetch_step` layers. The weights can be used for the first `prefetch_step` layers in `forward()`, so also clone the weights. - self.shared_windows.append( - SharedWindowMetadata( - weight=layer.weight.clone().detach(), - data_layer_idx=layer_idx, - work=None, - )) - layer.window_idx = step - # When the layer not intended to be stored in this device, link to the corresponding window's tensor. - if not is_source: - layer.weight.set_(self.shared_windows[-1].weight) - else: - # Build one more window for prefetch. The weight is useless, so just keep the shape. - if step == self.prefetch_step: - self.shared_windows.append( - SharedWindowMetadata( - weight=torch.empty_like(layer.weight), - data_layer_idx=-1, - work=None, - )) - # When the layer not intended to be stored in this device, dispose the tensor. - if not is_source: - dispose_tensor(layer.weight) - - dispose_tensor(self.dummy_weight) - - def reach_layer(self, layer_idx: int): - # The index of the layer to be prefetched. - next_layer_idx = (layer_idx + self.prefetch_step - ) % self.num_layers + self.start_layer - next_layer = self.layers[next_layer_idx - self.start_layer] - # The index of the window to store the weight for the coming layer. - next_layer.window_idx = self.window_offset - window = self.shared_windows[next_layer.window_idx] - # When the layer not intended to be stored in this device, link to the corresponding window's tensor. - if not self.is_source(next_layer_idx): - next_layer.weight.set_(window.weight) - # Update `window_offset` by rolling one step. - self.window_offset = (self.window_offset + 1) % (self.prefetch_step + - 1) - assert window.data_layer_idx != next_layer_idx - window.data_layer_idx = next_layer_idx - # Start asynchronous broadcast work. 
- window.work = dist.broadcast( - next_layer.weight, - src=self.group.ranks[next_layer_idx % self.group.world_size], - group=self.group.device_group, - async_op=True) - - def wait_weight(self, layer_idx: int): - # Find the asynchronous broadcast work and wait for it. - assert self.shared_windows - window = self.shared_windows[self.layers[layer_idx - - self.start_layer].window_idx] - # Make sure the data in the corresponding shared window is for the current layer. - assert window.data_layer_idx == layer_idx - if window.work is not None: - window.work.wait() - window.work = None - - -@dataclass -class LayerExternalMetadata: - """External metadata for a layer. - """ - series: SeriesMetadata - layer_idx: int - - -_series_dict: dict[str, SeriesMetadata] = {} - -_layer_external_dict: dict[int, LayerExternalMetadata] = {} - - -def _create_forward_wrapper(forward: Callable, series: SeriesMetadata, - layer_idx: int) -> Callable: - - def wrapped_forward(*args, **kwargs): - # Wait for the weight. - series.wait_weight(layer_idx) - return forward(*args, **kwargs) - - return wrapped_forward - - -""" -Register linear layers into a shared storage series. - -In a parallel group, each device stores a distinct, non-overlapping subset of layers from the series. All layers in a series must have the same structure (are isomorphic). The weight matrix for the i-th layer is stored on device (i % n), where n is the number of devices. - -After loading the model, you must call `post_process_after_loading_for_shared_weight_series(layer)` on any layer of this series to complete the initialization. - -During execution, each time a new layer is reached, you must call `reach_layer_for_shared_weight_series(layer)` for that layer to prefetch the weights. The argument `prefetch_step` is a non-negative integer k that manages asynchronous weight prefetching. Each call to `reach_layer_for_shared_weight_series(current_layer)` method will trigger an asynchronous prefetch for the weights of the k-th subsequent layer after `current_layer` within the series. - -Note: The layers are managed as a circular buffer. The index of the layer to prefetch is determined by the formula: -- total_layers = end_layer - start_layer -- prefetch_layer_idx = (layer_idx + prefetch_step) % total_layers + start_layer - -To hold the weights for the current layer and the k prefetched layers, a pool of (k + 1) shared tensor buffers will be created for this series. - -Arguments: - series_name: This name identifies which series this layer belongs to. - group: The group coordinator for handling asynchronous communications. It is recommended to create a new group coordinator for each new series. - start_layer: The index of the first layer in the series (inclusive). - end_layer: The index of the last layer in the series (exclusive). Thus, the series includes all layers with indices in the range [start_layer, end_layer). - layer_idx: The index of the current layer. - layer: The linear layer object to register. - prefetch_step: An integer that manages asynchronous weight prefetching. Setting it to 0 or 1 can cover most cases. 
-""" - - -def register_layer_to_shared_weight_series( - series_name: str, - group: GroupCoordinator, - start_layer: int, - end_layer: int, - layer_idx: int, - layer: LinearBase, - prefetch_step: int = 1, -): - global _series_dict - if series_name not in _series_dict: - num_layers = end_layer - start_layer - assert num_layers > 0 - assert prefetch_step >= 0 and prefetch_step <= num_layers - 2 - _series_dict[series_name] = SeriesMetadata( - group=group, - start_layer=start_layer, - end_layer=end_layer, - num_layers=num_layers, - prefetch_step=prefetch_step, - dummy_weight=torch.empty_like(layer.weight), - layers=[ - LayerMetadata( - layer=None, - post_method=lambda layer: None, - weight=torch.empty([]), - window_idx=-1, - ) for _ in range(num_layers) - ], - shared_windows=[], - window_offset=prefetch_step, - ) - series = _series_dict[series_name] - assert layer.quant_method is not None - series.layers[layer_idx - start_layer] = LayerMetadata( - layer=layer, - post_method=layer.quant_method.process_weights_after_loading, - weight=layer.weight, - window_idx=-1, - ) - # Discard the original `process_weights_after_loading` method such that it won't be called by others. - layer.quant_method.process_weights_after_loading = lambda layer: None - # When the layer not intended to be stored in this device, dispose the tensor and skip weight loading. - if not series.is_source(layer_idx): - dispose_tensor(layer.weight) - layer.weight.weight_loader = lambda *args, **kwargs: None - layer.forward = _create_forward_wrapper(layer.forward, series, layer_idx) - global _layer_external_dict - _layer_external_dict[id(layer)] = LayerExternalMetadata( - series=series, - layer_idx=layer_idx, - ) - - -def post_process_after_loading_for_shared_weight_series(layer: LinearBase): - ext = _layer_external_dict[id(layer)] - ext.series.post_process_after_loading() - - -def reach_layer_for_shared_weight_series(layer: LinearBase): - ext = _layer_external_dict[id(layer)] - ext.series.reach_layer(ext.layer_idx) diff --git a/vllm_ascend/torchair/ops/torchair_activation.py b/vllm_ascend/torchair/ops/torchair_activation.py deleted file mode 100644 index 0089b663..00000000 --- a/vllm_ascend/torchair/ops/torchair_activation.py +++ /dev/null @@ -1,37 +0,0 @@ -# -# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# This file is a part of the vllm-ascend project. -# - -import torch - - -def torchair_silu_and_mul_forward_oot(self, x: torch.Tensor) -> torch.Tensor: - """AscendSiluAndMul forward in torchair mode. - - The key difference from the original implementation is the removal of operators - from the torch.ops.vllm class, as these operators only function in non-torchair - modes. Adding them back would cause the graph compilation to fail. 
- """ - - import torch_npu - - from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type - - if get_ascend_device_type() == AscendDeviceType._310P: - out = torch_npu.npu_swiglu(x.to(torch.float32)).to(torch.float16) - else: - out = torch_npu.npu_swiglu(x) - return out diff --git a/vllm_ascend/torchair/ops/torchair_fused_moe.py b/vllm_ascend/torchair/ops/torchair_fused_moe.py deleted file mode 100644 index 5892d612..00000000 --- a/vllm_ascend/torchair/ops/torchair_fused_moe.py +++ /dev/null @@ -1,1436 +0,0 @@ -# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. -# Copyright 2023 The vLLM team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# This file is a part of the vllm-ascend project. -# Adapted from vllm/tests/kernels/test_moe.py - -import os -from typing import Any, Callable, Dict, Optional, Tuple, Union - -import torch -import torch.distributed as dist -import torch_npu -from torch import nn -from vllm.config import get_current_vllm_config -from vllm.distributed import (GroupCoordinator, get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, - tensor_model_parallel_all_reduce) -from vllm.distributed.parallel_state import (get_dp_group, get_ep_group, - get_tp_group) -from vllm.forward_context import get_forward_context -from vllm.logger import logger -from vllm.model_executor.layers.fused_moe.config import \ - FusedMoEConfig # isort: skip -from vllm.model_executor.layers.fused_moe.config import \ - FusedMoEParallelConfig # isort: skip -from vllm.model_executor.layers.fused_moe.layer import ( - FusedMoE, UnquantizedFusedMoEMethod, determine_expert_map, - get_compressed_expert_map) -from vllm.model_executor.layers.quantization.base_config import \ - QuantizationConfig - -from vllm_ascend.ascend_config import get_ascend_config -from vllm_ascend.ascend_forward_context import FusedMoEState -from vllm_ascend.distributed.parallel_state import get_mc2_group -from vllm_ascend.eplb.core.eplb_utils import determine_default_log2phy_map -from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer -from vllm_ascend.quantization.quant_config import (AscendFusedMoEMethod, - AscendQuantConfig) -from vllm_ascend.quantization.utils import get_quant_method -from vllm_ascend.torchair.ops.sequence_parallel import MetadataForPadding -from vllm_ascend.torchair.utils import (get_all_reduce_merge_state, - get_rm_router_logits_state, - npu_stream_switch, npu_wait_tensor, - super_kernel) -from vllm_ascend.utils import (AscendDeviceType, dispose_tensor, - get_ascend_device_type, - is_hierarchical_communication_enabled) - - -def torchair_fused_experts_with_mc2( - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - top_k: int, - moe_parallel_config: FusedMoEParallelConfig, - expert_map: torch.Tensor = None, - moe_all_to_all_group_name: Optional[str] = None, - shared_experts: Optional[Any] = None, - is_torchair: bool = False, - mc2_mask: Optional[torch.Tensor] = None, -) -> 
Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: - quant_mode = 0 - ep_rank_id = moe_parallel_config.ep_rank - ep_world_size = moe_parallel_config.ep_size - - # NOTE: Currently, when in A3 or in torchair graph, we need to pass in some extra param into dispatch & combine - need_extra_args = (get_ascend_device_type() == AscendDeviceType._910_93 - or is_torchair) - - # NOTE: Currently, when in A3, we need to pass in some extra param into dispatch & combine - a3_need_extra_args = get_ascend_device_type() == AscendDeviceType._910_93 - # NOTE: When in A2, setting the environment variables HCCL_INTRA_PCIE_ENABLE=1 and - # HCCL_INTRA_ROCE_ENABLE=0 can reduce cross-machine communication traffic and significantly - # improve communication performance. - need_expert_scale = is_hierarchical_communication_enabled() - - enable_dispatch_v2 = hasattr(torch_npu, "npu_moe_distribute_dispatch_v2") - - moe_expert_num = len(expert_map) - kwargs_mc2 = { - "x": hidden_states, - "expert_ids": topk_ids, - "expert_shard_type": 0, - "shared_expert_rank_num": 0, - "moe_expert_num": moe_expert_num, - "global_bs": 0, - } - - stage1_kwargs = { - "scales": None, - "quant_mode": quant_mode, - "group_ep": moe_all_to_all_group_name, - "ep_world_size": ep_world_size, - "ep_rank_id": ep_rank_id, - } - if need_extra_args: - stage1_kwargs.update({ - "group_tp": moe_all_to_all_group_name, - "tp_world_size": 1, - "tp_rank_id": 0, - }) - if a3_need_extra_args and enable_dispatch_v2: - stage1_kwargs.update({ - "x_active_mask": mc2_mask, - }) - if need_expert_scale: - stage1_kwargs.update({ - "expert_scales": topk_weights.to(torch.float32), - }) - - kwargs_mc2.update(stage1_kwargs) - - output = torch_npu.npu_moe_distribute_dispatch_v2( - **kwargs_mc2 - ) if enable_dispatch_v2 else torch_npu.npu_moe_distribute_dispatch( - **kwargs_mc2) - # comm_stream.wait_stream(torch.npu.current_stream()) - expand_x, dynamic_scale, assist_info_for_combine, expert_token_nums, \ - ep_recv_counts, _, expand_scales = output[0:7] - - if shared_experts is not None: - with npu_stream_switch("moe_secondary", 0): - npu_wait_tensor(hidden_states, topk_weights) - shared_gate_up, _ = shared_experts.gate_up_proj(hidden_states) - npu_wait_tensor(shared_gate_up, expand_x) - shared_act = shared_experts.act_fn(shared_gate_up) - - w1 = w1.transpose(1, 2) - - group_list = expert_token_nums.to(torch.int64) - gate_up_out_list = torch_npu.npu_grouped_matmul( - x=[expand_x], - weight=[w1], - split_item=2, - # 1 means count mode, to avoid cumulative operation of the group list - group_list_type=1, - group_type=0, - group_list=group_list, - )[0] - - gate_up_out = torch_npu.npu_swiglu(gate_up_out_list) - - w2 = w2.transpose(1, 2) - down_out_list = torch_npu.npu_grouped_matmul( - x=[gate_up_out], - weight=[w2], - split_item=2, - group_list_type=1, - group_type=0, - group_list=group_list, - )[0] - - # moeCombine - kwargs_mc2 = { - "expand_x": down_out_list, - "expert_ids": topk_ids, - "expert_scales": topk_weights.to(torch.float32), - "expert_shard_type": 0, - "shared_expert_rank_num": 0, - "moe_expert_num": moe_expert_num, - "global_bs": 0, - } - tp_recv_counts = output[5] - stage3_kwargs = { - "ep_send_counts": ep_recv_counts, - "group_ep": moe_all_to_all_group_name, - "ep_world_size": ep_world_size, - "ep_rank_id": ep_rank_id, - "expand_scales": expand_scales, - } - if enable_dispatch_v2: - stage3_kwargs.update({ - "assist_info_for_combine": - assist_info_for_combine, - }) - else: - stage3_kwargs.update({ - "expand_idx": assist_info_for_combine, - }) - if 
need_extra_args: - stage3_kwargs.update({ - "tp_send_counts": tp_recv_counts, - "group_tp": moe_all_to_all_group_name, - "tp_world_size": 1, - "tp_rank_id": 0, - }) - if a3_need_extra_args and enable_dispatch_v2: - stage3_kwargs.update({ - "x_active_mask": mc2_mask, - }) - kwargs_mc2.update(stage3_kwargs) - - hidden_states = torch_npu.npu_moe_distribute_combine_v2( - **kwargs_mc2 - ) if enable_dispatch_v2 else torch_npu.npu_moe_distribute_combine( - **kwargs_mc2) - - if shared_experts is None: - return hidden_states - else: - with npu_stream_switch("moe_secondary", 0): - npu_wait_tensor(shared_act, down_out_list) - shared_hidden_states, _ = shared_experts.down_proj(shared_act) - return hidden_states, shared_hidden_states - - -def torchair_apply_mlp( - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - group_list: torch.Tensor, - group_list_type: int = 1, -) -> torch.Tensor: - """ - apply MLP: gate_up_proj -> swiglu -> down_proj - - Args: - hidden_states_wrapper: wrapper of input hidden states with shape (num_tokens, hidden_size). - w1: expert weights1 with shape - (num_experts, hidden_size, intermediate_size * 2) - w2: expert weights2 with shape - (num_experts, intermediate_size, hidden_size) - group_list: number of tokens for each expert, follow cumsum mode, and - with shape (num_experts). - transpose_weight: - w1: (num_experts, intermediate_size * 2, hidden_size) -> - (num_experts, hidden_size, intermediate_size * 2) - w2: (num_experts, hidden_size, intermediate_size) -> - (num_experts, intermediate_size, hidden_size) - - Returns: - hidden_states: output hidden states after MLP. - """ - - w1 = w1.transpose(1, 2) - hidden_states = torch_npu.npu_grouped_matmul( - x=[hidden_states], - weight=[w1], - split_item=2, - group_list_type=group_list_type, - group_type=0, - group_list=group_list, - )[0] - - hidden_states = torch_npu.npu_swiglu(hidden_states) - - w2 = w2.transpose(1, 2) - hidden_states = torch_npu.npu_grouped_matmul( - x=[hidden_states], - weight=[w2], - split_item=2, - group_list_type=group_list_type, - group_type=0, - group_list=group_list, - )[0] - - return hidden_states - - -# currently expert parallelism implemented with all2all -# is under-optimized. 
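# For illustration only: a minimal CPU sketch of the gate_up_proj -> SwiGLU -> down_proj
# flow that torchair_apply_mlp above expresses with torch_npu.npu_grouped_matmul. It
# assumes tokens are already sorted by expert, that w1/w2 are given in the post-transpose
# layouts (num_experts, hidden, intermediate * 2) and (num_experts, intermediate, hidden),
# and that `group_list` holds cumulative per-expert token counts (cumsum mode). Plain
# torch matmuls stand in for the NPU grouped kernels, so this is a reference sketch of
# the computation, not the torchair execution path.
import torch
import torch.nn.functional as F


def reference_grouped_mlp(hidden_states: torch.Tensor, w1: torch.Tensor,
                          w2: torch.Tensor,
                          group_list: torch.Tensor) -> torch.Tensor:
    # hidden_states: (num_tokens, hidden), rows grouped by expert id
    # group_list: (num_experts,) cumulative token counts, e.g. [3, 3, 7, ...]
    outputs = []
    start = 0
    for expert_id, end in enumerate(group_list.tolist()):
        x = hidden_states[start:end]
        gate, up = (x @ w1[expert_id]).chunk(2, dim=-1)  # gate_up_proj, split in half
        outputs.append((F.silu(gate) * up) @ w2[expert_id])  # SwiGLU, then down_proj
        start = end
    return torch.cat(outputs, dim=0)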
-def torchair_fused_experts_with_all2all( - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - top_k: int, - expert_map: torch.Tensor = None, - ep_group: GroupCoordinator = None, -): - original_shape = hidden_states.shape - if len(original_shape) == 3: - hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) - - num_tokens, _ = hidden_states.shape - num_experts = w1.shape[0] - device = hidden_states.device - - if expert_map is not None: - global_num_experts = len(expert_map) - local_num_experts = global_num_experts // ep_group.world_size - row_idx_len = num_tokens * top_k - row_idx = (torch.arange(0, - row_idx_len, - dtype=torch.int32, - device=device).view(top_k, -1).permute( - 1, 0).contiguous()) - hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing( - hidden_states, - row_idx=row_idx, - expert_idx=topk_ids, - active_num=num_tokens) - - global_expert_tokens = torch.bincount(expanded_expert_idx, - minlength=global_num_experts) - scatter_sizes = global_expert_tokens.view(ep_group.world_size, - -1).sum(-1) - - gather_sizes = torch.empty_like(scatter_sizes) - dist.all_to_all_single(gather_sizes, - scatter_sizes, - group=ep_group.device_group) - scatter_size_list = scatter_sizes.cpu().tolist() - gather_size_list = gather_sizes.cpu().tolist() - - expanded_expert_idx = expanded_expert_idx % local_num_experts - hidden_states = ep_group.all_to_all(hidden_states, 0, 0, - scatter_size_list, - gather_size_list) - local_expert_idx = ep_group.all_to_all(expanded_expert_idx, 0, 0, - scatter_size_list, - gather_size_list) - - sorted_local_expert_idx, sorted_idx = torch.sort(local_expert_idx) - - expert_tokens = torch_npu.npu_moe_compute_expert_tokens( - sorted_local_expert_idx, local_num_experts).to(torch.int64) - - hidden_states = hidden_states[sorted_idx] - else: - row_idx_len = num_tokens * top_k - row_idx = torch.arange(0, - row_idx_len, - dtype=torch.int32, - device=topk_weights.device).view( - top_k, -1).permute(1, 0).contiguous() - hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing( - hidden_states, - row_idx=row_idx, - expert_idx=topk_ids, - active_num=num_tokens) - - expert_tokens = torch_npu.npu_moe_compute_expert_tokens( - expanded_expert_idx, num_experts) - expert_tokens = expert_tokens.to(torch.int64) - - w1 = w1.transpose(1, 2) - gate_up_out_list = torch_npu.npu_grouped_matmul( - x=[hidden_states], - weight=[w1], - split_item=2, - group_list_type=0, - group_type=0, - group_list=expert_tokens, - )[0] - - hidden_states = torch_npu.npu_swiglu(gate_up_out_list) - - w2 = w2.transpose(1, 2) - hidden_states = torch_npu.npu_grouped_matmul( - x=[hidden_states], - weight=[w2], - split_item=2, - group_list_type=0, - group_type=0, - group_list=expert_tokens, - )[0] - - if expert_map is not None: - resorted_idx = torch.argsort(sorted_idx) - hidden_states = hidden_states[resorted_idx] - hidden_states = ep_group.all_to_all(hidden_states, 0, 0, - gather_size_list, - scatter_size_list) - - final_hidden_states = torch_npu.npu_moe_finalize_routing( - hidden_states, - skip1=None, - skip2=None, - bias=None, - scales=topk_weights, - expanded_src_to_dst_row=expanded_row_idx, - export_for_source_row=topk_ids, - ) - else: - # TODO: Reorder device memory 2 times here, replace the current - # implementation here when suitable operators become available. 
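# The npu_moe_finalize_routing call below scatters the per-expert outputs back to their
# original token rows (via expanded_src_to_dst_row) and applies the per-token routing
# weights passed in `scales`, producing the combined MoE output for each token.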
- final_hidden_states = torch_npu.npu_moe_finalize_routing( - hidden_states, - skip1=None, - skip2=None, - bias=None, - scales=topk_weights, - expanded_src_to_dst_row=expanded_row_idx, - export_for_source_row=topk_ids, - ) - if len(original_shape) == 3: - final_hidden_states = final_hidden_states.view(original_shape) - return final_hidden_states - - -def torchair_fused_experts_moge( - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - moe_parallel_config: FusedMoEParallelConfig, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - top_k: int, - global_num_experts: int, - expert_map: torch.Tensor = None, - apply_router_weight_on_input: bool = False, -) -> torch.Tensor: - """ - - Args: - hidden_states: Hidden states of shape (num_tokens, hidden_size). - w1: Expert weights1 of shape (num_experts, intermediate_size * 2, hidden_size). - w2: Expert weights2 of shape (num_experts, hidden_size, intermediate_size). - topk_weights: Routing weights of shape (num_tokens, top_k). - topk_ids: Selected expert IDs of shape (num_tokens, top_k). - top_k: Number of experts to select. - expert_map: Expert mapping of shape (num_experts,). - - Returns: - hidden_states: Hidden states after routing. - """ - ep_size = moe_parallel_config.ep_size - local_num_experts = global_num_experts // ep_size - local_num_group = top_k // ep_size - - if apply_router_weight_on_input: - assert (topk_weights.dim() == 2 - ), "`topk_weights` should be in shape (num_tokens, topk)" - _, topk = topk_weights.shape - assert ( - topk == 1 - ), "Only support topk=1 when `apply_router_weight_on_input` is True" - hidden_states = hidden_states * topk_weights.to(hidden_states.dtype) - - bsz, _ = hidden_states.shape - flatten_topk_ids = topk_ids.view(-1) - sorted_topk_ids = torch.argsort(flatten_topk_ids.float()) - sorted_topk_ids = sorted_topk_ids.to(torch.int32) - sorted_hidden_states = hidden_states.index_select( - 0, sorted_topk_ids // local_num_group) - - experts_id = torch.arange(0, - local_num_experts, - dtype=topk_ids.dtype, - device=topk_ids.device) - num_tokens_per_expert = (flatten_topk_ids.unsqueeze(-1) == experts_id).to( - torch.float32).sum(0) - topk_scales = topk_weights.view(-1).index_select( - 0, sorted_topk_ids).unsqueeze(-1) - group_list = num_tokens_per_expert.cumsum(dim=0).to(torch.int64) - - w1 = w1.transpose(1, 2) - gate_up_out = torch_npu.npu_grouped_matmul( - x=[sorted_hidden_states], - weight=[w1], - split_item=2, - group_list_type=0, - group_type=0, - group_list=group_list, - )[0] - - if get_ascend_device_type() == AscendDeviceType._310P: - gate_up_out = torch_npu.npu_swiglu(gate_up_out.to(torch.float32)).to( - torch.float16) - else: - gate_up_out = torch_npu.npu_swiglu(gate_up_out) - gate_up_out *= topk_scales - - w2 = w2.transpose(1, 2) - down_out_list = torch_npu.npu_grouped_matmul( - x=[gate_up_out], - weight=[w2], - split_item=2, - group_list_type=0, - group_type=0, - group_list=group_list, - )[0] - - unsorted_topk_ids = torch.argsort(sorted_topk_ids.float()).to(torch.int32) - unsorted_hidden_states = down_out_list.index_select(0, unsorted_topk_ids) - final_hidden_states = unsorted_hidden_states.reshape( - bsz, top_k // ep_size, -1).sum(1) - - return final_hidden_states - - -def torchair_fused_experts( - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - top_k: int, - expert_map: torch.Tensor = None, - apply_router_weight_on_input: bool = False, - max_num_tokens: Optional[int] = None, -) -> torch.Tensor: - """ - 
Fused experts with top-k routing. - - Args: - hidden_states: Hidden states of shape (num_tokens, hidden_size). - w1: Expert weights1 of shape (num_experts, intermediate_size * 2, hidden_size). - w2: Expert weights2 of shape (num_experts, hidden_size, intermediate_size). - topk_weights: Routing weights of shape (num_tokens, top_k). - topk_ids: Selected expert IDs of shape (num_tokens, top_k). - top_k: Number of experts to select. - expert_map: Expert mapping of shape (num_experts,). - - Returns: - hidden_states: Hidden states after routing. - """ - """ - # Check constraints. - assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" - assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" - assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" - assert w1.is_contiguous(), "Expert weights1 must be contiguous" - assert w2.is_contiguous(), "Expert weights2 must be contiguous" - """ - # if torch.distributed.get_rank() == 0: - # print(w1.shape) - # print(hidden_states.shape) - - original_shape = hidden_states.shape - # assert len(original_shape) == 2 - - num_tokens = hidden_states.shape[:-1].numel() - num_experts = w1.shape[0] - dtype = hidden_states.dtype - device = hidden_states.device - # assert dtype in [torch.float32, torch.float16, torch.bfloat16 - # ], "Only float32, float16, and bfloat16 are supported" - - if apply_router_weight_on_input: - assert (topk_weights.dim() == 2 - ), "`topk_weights` should be in shape (num_tokens, topk)" - _, topk = topk_weights.shape - assert ( - topk == 1 - ), "Only support topk=1 when `apply_router_weight_on_input` is True" - hidden_states = hidden_states * topk_weights.to(hidden_states.dtype) - - if expert_map is not None: - # Generate token indices and flatten - token_indices = (torch.arange(num_tokens, - device=device, - dtype=torch.int64).unsqueeze(1).expand( - -1, top_k).reshape(-1)) - - # Flatten token-to-expert mappings and map to local experts - weights_flat = topk_weights.view(-1) - experts_flat = topk_ids.view(-1) - local_experts_flat = expert_map[experts_flat] - - # Filter valid token-expert pairs - mask = local_experts_flat != -1 - filtered_weights = torch.where( - mask, weights_flat, torch.zeros_like(weights_flat)).to(dtype) - filtered_experts = torch.where( - mask, local_experts_flat, - torch.full_like(local_experts_flat, - num_experts)).to(topk_ids.dtype) - - # Sort by local expert IDs - sort_indices = torch.argsort(filtered_experts.view(torch.float32)) - sorted_token_indices = token_indices[sort_indices] - sorted_weights = filtered_weights[sort_indices] - - # Compute token counts with minlength of num_experts - # This is equivalent to but faster than: - # >>> token_counts = torch.bincount(filtered_experts, minlength=num_experts)[:-1] - token_counts = torch.zeros(num_experts + 1, - device=device, - dtype=torch.int64) - ones = torch.ones_like(filtered_experts, dtype=torch.int64) - token_counts.scatter_add_(0, filtered_experts.to(torch.int64), ones) - token_counts = token_counts[:num_experts] - expert_tokens = torch.cumsum(token_counts, dim=0, dtype=torch.int64) - - # Rearrange hidden_states - sorted_hidden_states = hidden_states[sorted_token_indices] - else: - row_idx_len = num_tokens * top_k - row_idx = (torch.arange(0, - row_idx_len, - dtype=torch.int32, - device=device).view(top_k, -1).permute( - 1, 0).contiguous()) - active_num = max_num_tokens if max_num_tokens is not None else num_tokens - sorted_hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing( - 
hidden_states, - row_idx=row_idx, - expert_idx=topk_ids, - active_num=active_num) - - expert_tokens = torch_npu.npu_moe_compute_expert_tokens( - expanded_expert_idx, num_experts) - expert_tokens = expert_tokens.to(torch.int64) - - w1 = w1.transpose(1, 2) - gate_up_out_list = torch_npu.npu_grouped_matmul( - x=[sorted_hidden_states], - weight=[w1], - split_item=2, - group_list_type=0, - group_type=0, - group_list=expert_tokens, - )[0] - - gate_up_out = torch_npu.npu_swiglu(gate_up_out_list) - - w2 = w2.transpose(1, 2) - down_out_list = torch_npu.npu_grouped_matmul( - x=[gate_up_out], - weight=[w2], - split_item=2, - group_list_type=0, - group_type=0, - group_list=expert_tokens, - )[0] - - if expert_map is not None: - weighted_down_out = down_out_list * sorted_weights.unsqueeze(1) - - final_hidden_states = torch.zeros(*original_shape, - device=hidden_states.device, - dtype=dtype) - - # TODO: npu_grouped_matmul output random values at [num_valid_tokens:, ...] - # This created multiple NaN and index_add_ will mix them up which harms accuracy - # remove this mask and filter after it being fixed - num_valid_tokens = mask.sum() - valid_token_mask = torch.arange( - 0, sorted_token_indices.shape[0], - device=device).unsqueeze(1) < num_valid_tokens - valid_output = torch.where( - valid_token_mask, weighted_down_out, - torch.zeros_like(weighted_down_out)).to(dtype) - final_hidden_states.index_add_(0, sorted_token_indices, valid_output) - else: - scales = torch.ones_like( - topk_weights) if apply_router_weight_on_input else topk_weights - # TODO: Reorder device memory 2 times here, replace the current - # implementation here when suitable operators become available. - final_hidden_states = torch_npu.npu_moe_finalize_routing( - down_out_list, - skip1=None, - skip2=None, - bias=None, - scales=scales, - expanded_src_to_dst_row=expanded_row_idx, - export_for_source_row=topk_ids, - ) - - return final_hidden_states - - -def torchair_native_grouped_topk( - topk_weights: torch.Tensor, - num_expert_group: Optional[int], - topk_group: Optional[int], -): - topk_group = 0 if topk_group is None else topk_group - num_expert_group = 0 if num_expert_group is None else num_expert_group - - num_token = topk_weights.shape[0] - grouped_weights = topk_weights.view(num_token, num_expert_group, - -1).max(dim=-1).values - topk_group_indices = torch.topk(grouped_weights.to(torch.float32), - k=topk_group, - dim=-1, - sorted=False)[1] - topk_group_mask = torch.zeros_like(grouped_weights) - topk_group_mask.scatter_(1, topk_group_indices, 1) - topk_weight_mask = (topk_group_mask.unsqueeze(-1).expand( - num_token, num_expert_group, - topk_weights.shape[-1] // num_expert_group).reshape(num_token, -1)) - topk_weights = topk_weights.masked_fill(~topk_weight_mask.bool(), 0.0) - - return topk_weights - - -def torchair_select_experts( - hidden_states: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - use_grouped_topk: bool, - renormalize: bool, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, - custom_routing_function: Optional[Callable] = None, - scoring_func: str = "softmax", - e_score_correction_bias: Optional[torch.Tensor] = None, - global_num_experts: Optional[torch.Tensor] = None -) -> tuple[torch.Tensor, torch.Tensor]: - """ - Select top-k experts based on router logits. - - Args: - hidden_states: Hidden states of shape (num_tokens, hidden_size). - router_logits: Router logits of shape (num_tokens, num_experts). - top_k: Number of experts to select. 
- use_grouped_topk: Whether to group experts before selecting top-k. - renormalize: Whether to renormalize the routing weights. - topk_group: Number of expert groups to select from. - num_expert_group: Number of experts in each group. - custom_routing_function: Custom routing function. - scoring_func: Scoring function to use. - e_score_correction_bias: Correction bias to apply to expert scores. - - Returns: - topk_weights: Routing weights of shape (num_tokens, top_k). - topk_ids: Selected expert IDs of shape (num_tokens, top_k). - - Raises: - ValueError: If an unsupported scoring function is provided. - """ - - def _renormalize_topk_weights( - topk_weights: torch.Tensor, - renormalize: bool, - ): - if renormalize: - topk_weights = topk_weights / topk_weights.sum(dim=-1, - keepdim=True) - return topk_weights - - if scoring_func == "softmax": - # NOTE: vLLM use dtype=torch.float here - if not use_grouped_topk and custom_routing_function is None: - topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k_softmax( - x=router_logits, finished=None, k=top_k) - topk_ids = topk_ids.to(torch.int32) - topk_weights = _renormalize_topk_weights(topk_weights, renormalize) - return topk_weights, topk_ids - - topk_weights = router_logits.softmax(dim=-1) - elif scoring_func == "sigmoid": - topk_weights = router_logits.sigmoid() - else: - raise ValueError(f"Unsupported scoring function: {scoring_func}") - - if use_grouped_topk: - assert topk_group is not None - assert num_expert_group is not None - - if e_score_correction_bias is not None: - # Store original scores before applying correction bias. We use biased - # scores for expert selection but original scores for routing weights - original_weights = topk_weights - topk_weights = topk_weights + e_score_correction_bias.unsqueeze(0) - - # TODO: Change to npu_group_topk when the latest CANN and NNAL is available - # >>> torch_npu._npu_group_topk(topk_weights, group_num=num_expert_group, k=topk_group) - topk_weights = torchair_native_grouped_topk(topk_weights, - num_expert_group, - topk_group) - # TODO bfloat16 is not supported in torch.topk with ge graph. 
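# Toy illustration of the grouped selection above (all numbers assumed for the example):
# with 256 experts, num_expert_group=8 and topk_group=4, the scores are viewed as 8
# groups of 32, the 4 groups with the largest per-group maximum score are kept, the 128
# expert scores belonging to the discarded groups are masked to 0, and the top_k
# selection below then runs over the 128 surviving scores.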
- if e_score_correction_bias is not None: - topk_ids = torch.topk(topk_weights.to(torch.float32), - k=top_k, - dim=-1, - sorted=False)[1] - # Use original unbiased scores for the routing weights - topk_weights = original_weights.gather(1, topk_ids) - else: - topk_weights, topk_ids = torch.topk(topk_weights.to(torch.float32), - k=top_k, - dim=-1, - sorted=False) - topk_ids = topk_ids.to(torch.int32) - topk_weights = _renormalize_topk_weights(topk_weights, renormalize) - return topk_weights, topk_ids - - if custom_routing_function is not None: - topk_weights, topk_ids = custom_routing_function( - hidden_states=hidden_states, - gating_output=router_logits, - topk=top_k, - renormalize=renormalize, - global_num_experts=global_num_experts) - # Required by npu_moe_init_routing - topk_ids = topk_ids.to(torch.int32) - return topk_weights, topk_ids - - topk_weights, topk_ids = topk_weights.topk(top_k, dim=-1) - topk_weights = topk_weights.to(hidden_states.dtype) - - # Required by npu_moe_init_routing - topk_ids = topk_ids.to(torch.int32) - topk_weights = _renormalize_topk_weights(topk_weights, renormalize) - - return topk_weights, topk_ids - - -class TorchairAscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod): - - def __init__(self, moe: FusedMoEConfig = None): - - super().__init__(moe=moe) - vllm_config = get_current_vllm_config() - - self.global_batch_size = vllm_config.scheduler_config.max_num_seqs - self.max_model_len = vllm_config.model_config.max_model_len - - ascend_config = get_ascend_config() - self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled - self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp - - try: - device_group = get_mc2_group().device_group - # TODO: Try local_rank = ep_group.rank_in_group - local_rank = torch.distributed.get_rank(group=device_group) - backend = device_group._get_backend(torch.device("npu")) - self.moe_all_to_all_group_name = backend.get_hccl_comm_name( - local_rank) - except AttributeError: - self.moe_all_to_all_group_name = None - - def process_weights_after_loading(self, layer): - super(UnquantizedFusedMoEMethod, - self).process_weights_after_loading(layer) - layer.w13_weight = torch.nn.Parameter(self._maybe_pad_weight( - layer.w13_weight.data), - requires_grad=False) - layer.w2_weight = torch.nn.Parameter(self._maybe_pad_weight( - layer.w2_weight.data), - requires_grad=False) - - def apply( - self, - layer: torch.nn.Module, - x: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - renormalize: bool, - use_grouped_topk: bool = False, - global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, - custom_routing_function: Optional[Callable] = None, - scoring_func: str = "softmax", - e_score_correction_bias: Optional[torch.Tensor] = None, - is_prefill: bool = False, - enable_force_load_balance: bool = False, - shared_experts: Optional[Any] = None, - **kwargs, - ) -> torch.Tensor: - global_redundant_expert_num = get_ascend_config( - ).init_redundancy_expert - is_deepseek_v3_r1 = global_num_experts - global_redundant_expert_num == 256 - # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern - if is_deepseek_v3_r1: - topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k( - router_logits, - k=top_k, # topk currently is 8 - bias=e_score_correction_bias, - k_group=topk_group, # fix: 4 - group_count=num_expert_group, # fix 8 - group_select_mode= - 1, # 0: the maximum in the group; 1: topk2.sum(fix) 
- renorm=0, # 0: softmax->topk(fix); 1: topk->softmax - norm_type=1, # 0: softmax; 1: sigmoid(fix) - # out_flag=False, # todo new api; should the third output be output - # y2_flag=False, # old api; should the third output be output - routed_scaling_factor=1, - eps=float(1e-20)) - else: - topk_weights, topk_ids = torchair_select_experts( - hidden_states=x, - router_logits=router_logits, - top_k=top_k, - use_grouped_topk=use_grouped_topk, - renormalize=renormalize, - topk_group=topk_group, - num_expert_group=num_expert_group, - custom_routing_function=custom_routing_function, - scoring_func=scoring_func, - e_score_correction_bias=e_score_correction_bias, - ) - - topk_weights = topk_weights.to(x.dtype) - # this is a naive implementation for experts load balance so as - # to avoid accumulating too much tokens on a single rank. - # currently it is only activated when doing profile runs. - if enable_force_load_balance: - topk_ids = torch.randint_like(topk_ids, 0, global_num_experts) - - fused_moe_state = get_forward_context().fused_moe_state - if self.enable_shared_expert_dp and fused_moe_state == FusedMoEState.MC2: - fused_moe_state = FusedMoEState.All2All - - if fused_moe_state == FusedMoEState.MC2: - return torchair_fused_experts_with_mc2( - hidden_states=x, - w1=layer.w13_weight, - w2=layer.w2_weight, - moe_parallel_config=self.moe.moe_parallel_config, - topk_weights=topk_weights, - topk_ids=topk_ids, - top_k=top_k, - expert_map=expert_map, - moe_all_to_all_group_name=self.moe_all_to_all_group_name, - shared_experts=shared_experts, - is_torchair=self.torchair_graph_enabled, - mc2_mask=kwargs.get("mc2_mask", None)) - elif fused_moe_state in [ - FusedMoEState.AllGather, FusedMoEState.NaiveMulticast - ]: - return torchair_fused_experts(hidden_states=x, - w1=layer.w13_weight, - w2=layer.w2_weight, - topk_weights=topk_weights, - topk_ids=topk_ids, - top_k=top_k, - expert_map=expert_map) - else: - return torchair_fused_experts_with_all2all( - hidden_states=x, - w1=layer.w13_weight, - w2=layer.w2_weight, - topk_weights=topk_weights, - topk_ids=topk_ids, - top_k=top_k, - expert_map=expert_map, - ep_group=get_ep_group()) - - -class TorchairAscendFusedMoEMethod(AscendFusedMoEMethod): - - def __init__(self, quant_config: AscendQuantConfig, prefix: str, - packed_modules_mapping: Dict[str, Any]): - self.quant_method = get_quant_method(quant_config.quant_description, - prefix, "moe", - packed_modules_mapping) - - -class TorchairAscendFusedMoE(FusedMoE): - - # The moe_counter parameter is required during the initialization of EPLB - # to identify the current layer index within the MOE model. 
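# Each TorchairAscendFusedMoE instance increments this class-level counter in __init__
# and records the value as `moe_instance_id`, which the ExpertLoadBalancer later uses to
# look up this layer's expert placement map and log2phy map.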
- moe_counter = -1 - - def __init__( - self, - num_experts: int, # Global number of experts - top_k: int, - hidden_size: int, - intermediate_size: int, - params_dtype: Optional[torch.dtype] = None, - reduce_results: bool = False, - renormalize: bool = True, - use_grouped_topk: bool = False, - num_expert_group: Optional[int] = None, - topk_group: Optional[int] = None, - quant_config: Optional[QuantizationConfig] = None, - tp_size: Optional[int] = None, - ep_size: Optional[int] = None, - dp_size: Optional[int] = None, - prefix: str = "", - custom_routing_function: Optional[Callable] = None, - scoring_func: str = "softmax", - e_score_correction_bias: Optional[torch.Tensor] = None, - activation: str = "silu", - apply_router_weight_on_input: bool = False, - ): - # TODO: This could not initialize FusedMoE baseclass, - # fixme and make __init__() of AscendFusedMoE more clear - super().__init__( - num_experts=num_experts, - top_k=top_k, - hidden_size=hidden_size, - intermediate_size=intermediate_size, - params_dtype=params_dtype, - reduce_results=reduce_results, - renormalize=renormalize, - use_grouped_topk=use_grouped_topk, - num_expert_group=num_expert_group, - topk_group=topk_group, - quant_config=quant_config, - tp_size=tp_size, - ep_size=ep_size, - dp_size=dp_size, - pcp_size=1, - prefix=prefix, - custom_routing_function=custom_routing_function, - scoring_func=scoring_func, - e_score_correction_bias=e_score_correction_bias, - activation=activation, - ) - TorchairAscendFusedMoE.moe_counter += 1 - self.moe_instance_id = TorchairAscendFusedMoE.moe_counter - self.prefix = prefix - - if params_dtype is None: - params_dtype = torch.get_default_dtype() - - vllm_config = get_current_vllm_config() - - self.moe_parallel_config = FusedMoEParallelConfig.make( - tp_size_=(tp_size if tp_size is not None else - get_tensor_model_parallel_world_size()), - # TODO: support pcp - pcp_size_=1, - dp_size_=(dp_size - if dp_size is not None else get_dp_group().world_size), - vllm_parallel_config=vllm_config.parallel_config) - - self.top_k = top_k - self.num_experts = num_experts - self.global_num_experts = num_experts - assert intermediate_size % self.tp_size == 0 - self.intermediate_size_per_partition = intermediate_size // self.tp_size - self.reduce_results = reduce_results - self.renormalize = renormalize - self.use_grouped_topk = use_grouped_topk - if self.use_grouped_topk: - assert num_expert_group is not None and topk_group is not None - self.num_expert_group = num_expert_group - self.topk_group = topk_group - self.custom_routing_function = custom_routing_function - self.scoring_func = scoring_func - self.e_score_correction_bias = e_score_correction_bias - self.expert_map = None - self.activation = activation - self.log2phy = None - self.global_redundant_expert_num = 0 - - is_deepseek_v3_r1 = self.global_num_experts == 256 - self.rm_router_logits = get_rm_router_logits_state( - self.moe_parallel_config.ep_size, self.dp_size, is_deepseek_v3_r1) - self.all_reduce_merge = get_all_reduce_merge_state( - self.moe_parallel_config.ep_size, is_deepseek_v3_r1) - - ascend_config = get_ascend_config() - self.dynamic_eplb = ascend_config.dynamic_eplb or ascend_config.expert_map_record_path - self.expert_map_path = ascend_config.expert_map_path - self.global_redundant_expert_num = ascend_config.init_redundancy_expert - self.global_num_experts = num_experts + self.global_redundant_expert_num - # static eplb initializing with expert_map_path - if self.expert_map_path and os.path.exists( - self.expert_map_path) and 
os.access(self.expert_map_path, - os.R_OK): - self.expert_load_balancer = ExpertLoadBalancer( - self.expert_map_path, num_experts) - self.expert_load_balancer.check_expert_map_tensor() - self.global_redundant_expert_num = ( - self.expert_load_balancer.get_global_redundant_expert_num()) - try: - self.local_num_experts, self.expert_map = ( - self.expert_load_balancer.get_rank_placement_map( - self.moe_instance_id, self.ep_rank)) - self.log2phy = self.expert_load_balancer.get_rank_log2phy_map( - self.moe_instance_id, self.ep_rank).npu() - self.global_num_experts = num_experts + self.global_redundant_expert_num - except Exception as e: - logger.warning( - f"Init expert map of mtp/eagle when using sample.{e}") - self.local_num_experts, self.expert_map = determine_expert_map( - self.ep_size, self.ep_rank, self.global_num_experts) - self.log2phy = determine_default_log2phy_map( - self.global_num_experts, self.ep_size, self.ep_rank).npu() - if self.expert_map is not None and isinstance( - self.expert_map, torch.Tensor): - logger.info_once( - "[EP Rank %s/%s] Expert parallelism is enabled. Local/global" - " number of experts: %s/%s. Experts local to global index map:" - " %s.", self.ep_rank, self.ep_size, self.local_num_experts, - self.global_num_experts, - get_compressed_expert_map(self.expert_map)) - else: - # init moe. - self.local_num_experts, self.expert_map, _ = determine_expert_map( - self.ep_size, self.ep_rank, self.global_num_experts) - # dynamic eplb initializing with not expert_map_path - if self.dynamic_eplb: - self.log2phy = determine_default_log2phy_map( - self.global_num_experts, self.ep_size, self.ep_rank).npu() - if self.expert_map is not None and isinstance( - self.expert_map, torch.Tensor): - logger.info_once( - "[EP Rank %s/%s] Expert parallelism is enabled. Local/global" - " number of experts: %s/%s. 
Experts local to global index map:" - " %s.", self.ep_rank, self.ep_size, self.local_num_experts, - self.global_num_experts, - get_compressed_expert_map(self.expert_map)) - local_num_experts = (torch.sum(self.expert_map != -1) - if self.expert_map is not None else num_experts) - if self.dynamic_eplb: - self.moe_load = torch.zeros(local_num_experts, - dtype=torch.int64).npu() - - self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled - self.multistream_overlap_shared_expert = \ - ascend_config.multistream_overlap_shared_expert and \ - self.torchair_graph_enabled - self.enable_super_kernel = ascend_config.torchair_graph_config.enable_super_kernel - self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp - - if self.scoring_func != "softmax" and not self.use_grouped_topk: - raise ValueError("Only softmax scoring function is supported for " - "non-grouped topk.") - self.moe = FusedMoEConfig( - num_experts=self.global_num_experts, - experts_per_token=top_k, - hidden_dim=hidden_size, - num_local_experts=self.local_num_experts, - moe_parallel_config=self.moe_parallel_config, - in_dtype=params_dtype, - ) - if quant_config is None: - self.quant_method = TorchairAscendUnquantizedFusedMoEMethod( - self.moe) - else: - if quant_config.is_layer_skipped_ascend( - prefix, quant_config.packed_modules_mapping): - self.quant_method = TorchairAscendUnquantizedFusedMoEMethod( - self.moe) - else: - self.quant_method = TorchairAscendFusedMoEMethod( - quant_config, prefix, quant_config.packed_modules_mapping) - - assert self.quant_method is not None - - self.moe_load = None - local_num_experts = (torch.sum(self.expert_map != -1) - if self.expert_map is not None else num_experts) - if self.dynamic_eplb: - self.moe_load = torch.zeros(local_num_experts, dtype=torch.int64) - - moe_quant_params = { - "num_experts": local_num_experts, - "hidden_size": hidden_size, - "intermediate_size_per_partition": - self.intermediate_size_per_partition, - "params_dtype": params_dtype, - "weight_loader": self.weight_loader, - } - # need full intermediate size pre-sharding for WNA16 act order - if (self.quant_method.__class__.__name__ - in ("GPTQMarlinMoEMethod", "CompressedTensorsWNA16MoEMethod")): - moe_quant_params["intermediate_size_full"] = intermediate_size - - self.ep_group = get_ep_group() - # NOTE: self.tp_group is not expert_tp_group - self.tp_group = get_tp_group().device_group - self.quant_method.create_weights(layer=self, **moe_quant_params) - - def naive_multicast(self, x: torch.Tensor, - cu_tokens_across_dp_cpu: torch.Tensor): - assert (len(x.shape) == 2) - buffer = torch.empty((cu_tokens_across_dp_cpu[-1], x.size(1)), - device=x.device, - dtype=x.dtype) - start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[ - self.dp_rank - 1] - end = cu_tokens_across_dp_cpu[self.dp_rank] - buffer[start:end, :].copy_(x) - for idx in range(self.dp_size): - start = 0 if idx == 0 else cu_tokens_across_dp_cpu[idx - 1] - end = cu_tokens_across_dp_cpu[idx] - get_dp_group().broadcast(buffer[start:end, :], idx) - return buffer - - def forward(self, - hidden_states: torch.Tensor, - router_logits: torch.Tensor, - is_prefill: bool, - enable_force_load_balance: bool = False, - top_k: Optional[int] = None, - shared_experts: Optional[Any] = None, - gate=None, - replace_allreduce: bool = False, - _metadata_for_padding: Optional[MetadataForPadding] = None): - - assert self.quant_method is not None - - if top_k: - real_top_k = top_k - else: - real_top_k = self.top_k - - num_tokens, hidden_size = 
hidden_states.shape - - forward_context = get_forward_context() - fused_moe_state = forward_context.fused_moe_state - mc2_mask = forward_context.mc2_mask - if self.enable_shared_expert_dp and fused_moe_state == FusedMoEState.MC2: - fused_moe_state = FusedMoEState.All2All - # For w8a8 dynamic we can do npu_dynamic_quant and gate in parallel. - quantized_x_for_share, dynamic_scale_for_share = None, None - from vllm_ascend.torchair.quantization.torchair_w8a8_dynamic import \ - TorchairAscendW8A8DynamicFusedMoEMethod - running_in_super_kernel = self.enable_super_kernel and fused_moe_state == FusedMoEState.MC2 - - if self.multistream_overlap_shared_expert: - with super_kernel(self.prefix, - "stream-fusion=1", - enabled=running_in_super_kernel): - if not self.rm_router_logits: - if self.enable_super_kernel: - router_logits, _ = gate(hidden_states.float()) - else: - router_logits, _ = gate(hidden_states) - if hasattr(self.quant_method, "quant_method") and \ - isinstance(self.quant_method.quant_method, - TorchairAscendW8A8DynamicFusedMoEMethod - ) and fused_moe_state == FusedMoEState.MC2: - with npu_stream_switch("moe_secondary", 0): - quantized_x_for_share, dynamic_scale_for_share = torch_npu.npu_dynamic_quant( - hidden_states) - - if shared_experts: - if not self.multistream_overlap_shared_expert or fused_moe_state != FusedMoEState.MC2: - # When all_reduce_merge is in progress, shared_experts does not do all_reduce in mlp, but waits until shared_experts+router_experts are completed before doing all_reduce - shared_hidden_states = shared_experts(hidden_states) - - mc2_mask = forward_context.mc2_mask - - enable_sp = _metadata_for_padding is not None and _metadata_for_padding.not_dummy_and_is_prefill - tp_size = get_tensor_model_parallel_world_size() - if enable_sp: - tp_rank = get_tensor_model_parallel_rank() - mc2_mask_sp = _metadata_for_padding.mc2_mask if _metadata_for_padding is not None else forward_context.mc2_mask - chunk_mc2_mask = torch.tensor_split(mc2_mask_sp, tp_size, dim=0) - mc2_mask = chunk_mc2_mask[tp_rank] - replace_allreduce = True - - if (fused_moe_state not in [ - FusedMoEState.AllGather, FusedMoEState.AllGatherEP, - FusedMoEState.NaiveMulticast - ]): - if tp_size > 1: - tp_rank = get_tensor_model_parallel_rank() - chunk_mc2_mask = torch.tensor_split(mc2_mask, tp_size, dim=0) - mc2_mask = chunk_mc2_mask[tp_rank] - if not replace_allreduce: - if fused_moe_state in {FusedMoEState.MC2}: - padding_size = forward_context.padded_num_tokens - else: - # TODO: Determine if we can remove the padding - padding_size = tp_size - if num_tokens < padding_size and not self.enable_shared_expert_dp: - hidden_states = nn.functional.pad( - hidden_states, (0, 0, 0, padding_size - num_tokens)) - router_logits = nn.functional.pad( - router_logits, (0, 0, 0, padding_size - num_tokens)) - if tp_size > 1: - tp_rank = get_tensor_model_parallel_rank() - if not self.enable_shared_expert_dp: - chunk_hidden_states = torch.tensor_split(hidden_states, - tp_size, - dim=0) - chunk_router_logits = torch.tensor_split(router_logits, - tp_size, - dim=0) - hidden_states = chunk_hidden_states[tp_rank] - router_logits = chunk_router_logits[tp_rank] - - if self.dp_size > 1: - if fused_moe_state == FusedMoEState.AllGather: - # NOTE: When in torchair graph, it has been padded in model_runner_v1 - if not self.torchair_graph_enabled: - max_tokens_across_dp = forward_context.max_tokens_across_dp - if num_tokens < max_tokens_across_dp: - hidden_states = nn.functional.pad( - hidden_states, - (0, 0, 0, max_tokens_across_dp - 
num_tokens)) - if not self.rm_router_logits: - router_logits = nn.functional.pad( - router_logits, - (0, 0, 0, max_tokens_across_dp - num_tokens)) - hidden_states = get_dp_group().all_gather(hidden_states, 0) - if self.rm_router_logits: - router_logits, _ = gate(hidden_states) - else: - router_logits = get_dp_group().all_gather(router_logits, 0) - - elif fused_moe_state == FusedMoEState.NaiveMulticast: - cu_tokens_across_dp_cpu = get_forward_context( - ).dp_metadata.cu_tokens_across_sp(1) - hidden_states = self.naive_multicast(hidden_states, - cu_tokens_across_dp_cpu) - if self.rm_router_logits: - router_logits, _ = gate(hidden_states) - else: - router_logits = self.naive_multicast( - router_logits, cu_tokens_across_dp_cpu) - - # Matrix multiply. - e_hidden_states = self.quant_method.apply( - layer=self, - x=hidden_states, - router_logits=router_logits, - top_k=real_top_k, - renormalize=self.renormalize, - use_grouped_topk=self.use_grouped_topk, - global_num_experts=self.global_num_experts, - expert_map=self.expert_map, - topk_group=self.topk_group, - num_expert_group=self.num_expert_group, - custom_routing_function=self.custom_routing_function, - scoring_func=self.scoring_func, - e_score_correction_bias=self.e_score_correction_bias, - is_prefill=is_prefill, - enable_force_load_balance=enable_force_load_balance, - log2phy=self.log2phy, - global_redundant_expert_num=self.global_redundant_expert_num, - shared_experts=shared_experts if self.torchair_graph_enabled - and self.multistream_overlap_shared_expert and not is_prefill else - None, - mc2_mask=mc2_mask, - quantized_x_for_share=quantized_x_for_share, - dynamic_scale_for_share=dynamic_scale_for_share, - prefix=self.prefix, - running_in_super_kernel=running_in_super_kernel, - ) - - if shared_experts: - if isinstance(e_hidden_states, - tuple) and len(e_hidden_states) == 2: - e_hidden_states, shared_hidden_states = e_hidden_states - - if isinstance(e_hidden_states, tuple) and len(e_hidden_states) == 4: - e_hidden_states, shared_hidden_states, group_list_type, expert_tokens = e_hidden_states - if self.dynamic_eplb: - self.moe_load += expert_tokens if group_list_type else \ - torch.cat([expert_tokens[:1], expert_tokens[1:] - expert_tokens[:-1]]) - - if shared_experts is None and isinstance( - e_hidden_states, tuple) and len(e_hidden_states) == 3: - e_hidden_states, group_list_type, expert_tokens = e_hidden_states - if self.dynamic_eplb: - self.moe_load += expert_tokens if group_list_type else \ - torch.cat([expert_tokens[:1], expert_tokens[1:] - expert_tokens[:-1]]) - - if (fused_moe_state not in [ - FusedMoEState.AllGather, FusedMoEState.AllGatherEP, - FusedMoEState.NaiveMulticast - ] and not replace_allreduce and not self.enable_shared_expert_dp): - if tp_size > 1: - if isinstance(e_hidden_states, tuple): - e_hidden_states = e_hidden_states[0] - dist.all_gather(list(chunk_hidden_states), e_hidden_states, - self.tp_group) - final_hidden_states = torch.cat(chunk_hidden_states, dim=0) - dispose_tensor(e_hidden_states) - else: - final_hidden_states = e_hidden_states - if num_tokens < padding_size: - final_hidden_states = final_hidden_states[:num_tokens] - elif self.dp_size > 1 and not self.enable_shared_expert_dp: - if fused_moe_state == FusedMoEState.NaiveMulticast: - start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[ - self.dp_rank - 1] - end = cu_tokens_across_dp_cpu[self.dp_rank] - final_hidden_states = get_dp_group().all_reduce( - e_hidden_states) - final_hidden_states = final_hidden_states[start:end, :] - 
dispose_tensor(e_hidden_states) - elif fused_moe_state == FusedMoEState.AllGather: - final_hidden_states = get_dp_group().reduce_scatter( - e_hidden_states, 0) - final_hidden_states = final_hidden_states[:num_tokens] - dispose_tensor(e_hidden_states) - else: - final_hidden_states = e_hidden_states - else: - final_hidden_states = e_hidden_states - - if tp_size > 1 and not self.all_reduce_merge and fused_moe_state in [ - FusedMoEState.AllGather, FusedMoEState.AllGatherEP, - FusedMoEState.NaiveMulticast - ]: - final_hidden_states = tensor_model_parallel_all_reduce( - final_hidden_states) - - if shared_experts: - return final_hidden_states, shared_hidden_states - else: - return final_hidden_states - - def update_expert_map(self, new_expert_map): - self.expert_map = new_expert_map - - def get_map(self): - return self.expert_map - - def get_log2phy_map(self): - return self.log2phy - - def clear_moe_load(self): - if self.moe_load is not None: - self.moe_load.zero_() - - # ----------------------------------------- TBO-related -------------------------------------------- - - def _forward_ms_fused_moe_comp( - self, - hidden_states: torch.Tensor, - router_logits: torch.Tensor, - is_prefill: bool, - real_top_k, - enable_force_load_balance: bool = False, - ): - hidden_states = self.quant_method.apply( - layer=self, - x=hidden_states, - router_logits=router_logits, - top_k=real_top_k, - renormalize=self.renormalize, - use_grouped_topk=self.use_grouped_topk, - global_num_experts=self.global_num_experts, - expert_map=self.expert_map, - topk_group=self.topk_group, - num_expert_group=self.num_expert_group, - custom_routing_function=self.custom_routing_function, - scoring_func=self.scoring_func, - e_score_correction_bias=self.e_score_correction_bias, - is_prefill=is_prefill, - enable_force_load_balance=enable_force_load_balance, - ) - - return hidden_states diff --git a/vllm_ascend/torchair/ops/torchair_layernorm.py b/vllm_ascend/torchair/ops/torchair_layernorm.py deleted file mode 100644 index 3a3146b8..00000000 --- a/vllm_ascend/torchair/ops/torchair_layernorm.py +++ /dev/null @@ -1,78 +0,0 @@ -# -# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# This file is a part of the vllm-ascend project. 
-# - -from typing import Optional, Tuple, Union - -import torch -from vllm.config import get_current_vllm_config -from vllm.model_executor.layers.layernorm import RMSNorm - -_original_re_init = RMSNorm.__init__ - - -def torchair_rmsnorm_init_( - self, - hidden_size: int, - eps: float = 1e-6, - var_hidden_size: Optional[int] = None, - has_weight: bool = True, - dtype: Optional[torch.dtype] = None, -) -> None: - _original_re_init(self, hidden_size, eps, var_hidden_size, has_weight, - dtype) - vllm_config = get_current_vllm_config() - self.bias = None - # quantization with anti_method m4 will generate none-zero norm bias - if vllm_config.quant_config is not None and \ - any("norm.bias" in name for name in vllm_config.quant_config.quant_description.keys()): - self.bias = torch.nn.Parameter(torch.zeros(hidden_size), - requires_grad=False) - - -def torchair_rmsnorm_forward_oot( - self, - x: torch.Tensor, - residual: Optional[torch.Tensor] = None, -) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: - """AscendRMSNorm forward in torchair mode. - - The key difference from the original implementation is the removal of operators - from the torch.ops.vllm class, as these operators only function in non-torchair - modes. Adding them back would cause the graph compilation to fail. - """ - - import torch_npu - - from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type - if residual is not None: - if get_ascend_device_type() == AscendDeviceType._310P: - orig_dtype = residual.dtype - x = x + residual.to(x.dtype) - residual = x.to(orig_dtype) - x, _ = torch_npu.npu_rms_norm(x, self.weight, - self.variance_epsilon) - else: - x, _, residual = torch_npu.npu_add_rms_norm( - x, residual, self.weight, self.variance_epsilon) - if self.bias is not None: - x.add_(self.bias) - return x, residual - - x, residual = torch_npu.npu_rms_norm(x, self.weight, self.variance_epsilon) - if self.bias is not None: - x.add_(self.bias) - return x diff --git a/vllm_ascend/torchair/ops/torchair_rotary_embedding.py b/vllm_ascend/torchair/ops/torchair_rotary_embedding.py deleted file mode 100644 index 9fdb231b..00000000 --- a/vllm_ascend/torchair/ops/torchair_rotary_embedding.py +++ /dev/null @@ -1,367 +0,0 @@ -# -# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# This file is a part of the vllm-ascend project. 
-# - -import math -from typing import Optional, Tuple - -import torch -import torch.nn.functional as F -import torch_npu -from vllm.model_executor.layers.rotary_embedding import ( - DeepseekScalingRotaryEmbedding, RotaryEmbedding) - -from vllm_ascend.ascend_config import get_ascend_config -from vllm_ascend.utils import (AscendDeviceType, enable_custom_op, - get_ascend_device_type) - - -def custom_rotary_embedding_enabled(query, neox_style, head_size): - return query.dtype == torch.float16 and neox_style and head_size % 32 == 0 and enable_custom_op( - ) - - -def rope_forward_oot( - self, - positions: torch.Tensor, - query: torch.Tensor, - key: torch.Tensor, - offsets: Optional[torch.Tensor] = None, - is_neox_style_override: Optional[bool] = None, - is_qwen_torchair: Optional[bool] = False, -) -> Tuple[torch.Tensor, torch.Tensor]: - if get_ascend_config( - ).torchair_graph_config.enabled and not is_qwen_torchair: - return self.forward_native( - positions, - query, - key, - offsets, - ) - - query_shape, key_shape = query.shape, key.shape - if self.cos_sin_cache.device != query.device: - self.cos_sin_cache = self.cos_sin_cache.to(query.device) - if self.cos_sin_cache.dtype != query.dtype: - self.cos_sin_cache = self.cos_sin_cache.to(query.dtype) - neox_style = self.is_neox_style - if is_neox_style_override is not None: - neox_style = is_neox_style_override - # adopt custom kernel path for rotary_embedding - if custom_rotary_embedding_enabled( - query, neox_style, self.head_size) and get_ascend_device_type( - ) != AscendDeviceType._310P: - query, key = torch.ops._C_ascend.rotary_embedding( - positions, - query, - key, - self.head_size, - self.cos_sin_cache, - neox_style, - ) - return query.view(query_shape), key.view(key_shape) - if offsets is not None: - raise NotImplementedError( - "Batched rotary embedding is currently not supported on NPU.") - else: - # TODO: Remove the contiguous in the future. - query = query.contiguous().view(query.shape[0], -1) - key = key.contiguous().view(key.shape[0], -1) - torch_npu._npu_rotary_embedding( - positions, - query, - key, - self.head_size, - self.cos_sin_cache, - neox_style, - ) - return query.view(query_shape), key.view(key_shape) - - -def native_rope_deepseek_forward(self, - positions: torch.Tensor, - query: torch.Tensor, - key: torch.Tensor, - offsets: Optional[torch.Tensor] = None): - if len(key.shape) == 2: - key = key[:, None, :] - # Note: we implement the non neox_style method with shuffle the last dim and neox style - # calculation method which is also more compute friendly to the ascend machine - # https://huggingface.co/deepseek-ai/DeepSeek-V3-0324/blob/main/modeling_deepseek.py - neox_style = True - if self.is_neox_style is False: - b, h_q, d = query.shape - query = query.view(b, h_q, d // 2, 2).transpose(3, - 2).reshape(b, h_q, d) - b, h_k, d = key.shape - key = key.view(b, h_k, d // 2, 2).transpose(3, 2).reshape(b, h_k, d) - q_pe, k_pe = rope_forward_oot(self, positions, query, key, offsets, - neox_style) - return q_pe, k_pe - - -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., :x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2:] - return torch.cat((-x2, x1), dim=-1) - - -# Inverse dim formula to find dim based on number of rotations -def yarn_find_correction_dim(num_rotations, - dim, - base=10000, - max_position_embeddings=2048): - # Note: use torch instead of math to solve MTP compilation error. 
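# The return below inverts the RoPE frequency schedule: dimension pair d has
# theta_d = base^(-2d/dim), so it completes max_position_embeddings * theta_d / (2*pi)
# rotations over the full context. Solving for d at a target rotation count gives
#   d = dim * ln(max_position_embeddings / (2*pi*num_rotations)) / (2 * ln(base))
# which is the expression computed here (in torch, to keep it traceable for MTP).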
- return (dim * torch.log( - torch.tensor(max_position_embeddings) / - (num_rotations * 2 * torch.pi))) / (2 * torch.log(torch.tensor(base))) - - -def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float: - if scale <= 1: - return 1.0 - return 0.1 * mscale * math.log(scale) + 1.0 - - -# Find dim range bounds based on rotations -def yarn_find_correction_range(low_rot, - high_rot, - dim, - base=10000, - max_position_embeddings=2048): - # Note: use torch instead of math to solve MTP compilation error. - low = torch.floor( - yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)) - high = torch.ceil( - yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings)) - # Note: use torch instead of max/min to solve MTP compilation error. - return torch.clamp(low, min=0), torch.clamp(high, max=dim - 1) - - -def yarn_linear_ramp_mask(min_value, max_value, dim): - # Note: The if conditional branch is not used here - # to solve MTP compilation error. - max_value += (min_value == max_value).float() * 0.001 - linear_func = (torch.arange(dim, dtype=torch.float32) - - min_value) / (max_value - min_value) - ramp_func = torch.clamp(linear_func, 0, 1) - return ramp_func - - -# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb -def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): - """Applies Rotary Position Embedding to the query and key tensors. - Args: - q (`torch.Tensor`): The query tensor. - k (`torch.Tensor`): The key tensor. - cos (`torch.Tensor`): The cosine part of the rotary embedding. - sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`): - The position indices of the tokens corresponding to the query and key tensors. For example, this can be - used to pass offsetted position ids when working with a KV-cache. - unsqueeze_dim (`int`, *optional*, defaults to 1): - The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and - sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note - that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and - k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes - cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have - the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. - Returns: - `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
- """ - cos = cos[position_ids] - sin = sin[position_ids] - cos = cos[:, None, None, :] - sin = sin[:, None, None, :] - - if len(q.shape) == 3: - q = q[:, :, None, :] - if len(k.shape) == 2: - k = k[:, None, None, :] - elif len(k.shape) == 3: - k = k[:, :, None, :] - - b, h_q, s, d = q.shape - q = q.view(b, h_q, s, d // 2, 2).transpose(4, 3).reshape(b, h_q, s, d) - - b, h_k, s, d = k.shape - k = k.view(b, h_k, s, d // 2, 2).transpose(4, 3).reshape(b, h_k, s, d) - - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - - q_embed = q_embed.view(b, h_q, d) - k_embed = k_embed.view(b, h_k, d) - - return q_embed, k_embed - - -def _set_cos_sin_cache(self, max_seq_len, device, dtype): - dim = self.rotary_dim - - freq_extra = 1.0 / (self.base**( - torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)) - freq_inter = 1.0 / (self.scaling_factor * self.base**( - torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)) - - low, high = yarn_find_correction_range( - self.beta_fast, - self.beta_slow, - dim, - self.base, - self.max_position_embeddings, - ) - inv_freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2).to( - device=device, dtype=torch.float32) - inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask - self.register_buffer("inv_freq", inv_freq, persistent=False) - - t = torch.arange(max_seq_len, device=device, dtype=torch.float32) - - freqs = torch.outer(t, inv_freq) - cos_cached = torch.cat([freqs, freqs], dim=-1).cos() * self.mscale - sin_cached = torch.cat([freqs, freqs], dim=-1).sin() * self.mscale - cos_cached = cos_cached.to(dtype) - sin_cached = sin_cached.to(dtype) - cache = torch.cat([freqs.cos() * self.mscale, - freqs.sin() * self.mscale], - dim=-1).to(dtype) - self.register_buffer("cos_sin_cache", cache, persistent=False) - self.register_buffer("cos_cached", cos_cached, persistent=False) - self.register_buffer("sin_cached", sin_cached, persistent=False) - - -def __set_cos_sin_cache(self, seq_len, device, dtype): - inv_freq = 1.0 / (self.base**(torch.arange( - 0, self.rotary_dim, 2, device=device, dtype=torch.float32) * - (1 / self.rotary_dim))) - self.register_buffer("inv_freq", inv_freq) - - t = torch.arange(self.max_position_embeddings, - device=self.inv_freq.device, - dtype=torch.float32) - freqs = torch.einsum("i,j->ij", t, self.inv_freq) - - emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos", emb.cos().to(dtype=dtype), persistent=False) - self.register_buffer("sin", emb.sin().to(dtype=dtype), persistent=False) - self.embed = F.embedding - - -_original_re_init = RotaryEmbedding.__init__ - - -def qwen_rope_init_func( - self, - head_size: int, - rotary_dim: int, - max_position_embeddings: int, - base: float, - is_neox_style: bool, - dtype: torch.dtype, -) -> None: - _original_re_init(self, head_size, rotary_dim, max_position_embeddings, - base, is_neox_style, dtype) - if get_ascend_config().torchair_graph_config.enabled: - __set_cos_sin_cache(self, - seq_len=max_position_embeddings, - device="npu", - dtype=dtype) - - -def rope_forward( - self, - positions: torch.Tensor, - query: torch.Tensor, - key: torch.Tensor, - offsets: Optional[torch.Tensor] = None, - is_neox_style_override: Optional[bool] = None, - max_seq_len: Optional[int] = None, - is_prefill: Optional[bool] = True, - is_qwen_torchair: Optional[bool] = False, -): - if get_ascend_config().torchair_graph_config.enabled \ - and is_qwen_torchair and not is_prefill: - if max_seq_len is not None and torch.gt(max_seq_len, - 
self.max_position_embeddings): - __set_cos_sin_cache(self, - seq_len=max_seq_len, - device=query.device, - dtype=torch.float32) - - # bsnd/bnsd - if positions is not None: - cos = self.embed(positions, self.cos) - sin = self.embed(positions, self.sin) - self.cos_embed = cos - self.sin_embed = sin - else: - cos = self.cos_embed - sin = self.sin_embed - - query = query.view(*query.shape[:-1], -1, self.head_size).contiguous() - key = key.view(*key.shape[:-1], -1, self.head_size).contiguous() - - cos = cos.unsqueeze(-2).unsqueeze(-2) - sin = sin.unsqueeze(-2).unsqueeze(-2) - - query = query.unsqueeze(1) - key = key.unsqueeze(1) - - q_embed, k_embed = torch_npu.npu_apply_rotary_pos_emb( - query, key, cos, sin) - return q_embed.flatten(-2), k_embed.flatten(-2) - else: - return rope_forward_oot(self, positions, query, key, offsets, - is_neox_style_override, - is_qwen_torchair) # type: ignore - - -def deepseek_rope_init_func( - self, - head_size: int, - rotary_dim: int, - max_position_embeddings: int, - base: int, - is_neox_style: bool, - scaling_factor: float, - dtype: torch.dtype, - *, - extrapolation_factor: float = 1, - attn_factor: float = 1, - beta_fast: int = 32, - beta_slow: int = 1, - mscale: float = 1, - mscale_all_dim: float = 0, -) -> None: - self.scaling_factor = scaling_factor - self.extrapolation_factor = extrapolation_factor - self.attn_factor = attn_factor - self.beta_fast = beta_fast - self.beta_slow = beta_slow - # Get n-d magnitude scaling corrected for interpolation. - self.mscale = float( - yarn_get_mscale(self.scaling_factor, float(mscale)) / - yarn_get_mscale(self.scaling_factor, float(mscale_all_dim)) * - attn_factor) - super(DeepseekScalingRotaryEmbedding, - self).__init__(head_size, rotary_dim, max_position_embeddings, base, - is_neox_style, dtype) - - # NOTE: For ascend friendly computing, reorder sin and cos cache - self.max_seq_len = math.ceil(max_position_embeddings * scaling_factor) - _set_cos_sin_cache(self, self.max_seq_len, dtype=dtype, device="npu") diff --git a/vllm_ascend/torchair/ops/torchair_vocab_parallel_embedding.py b/vllm_ascend/torchair/ops/torchair_vocab_parallel_embedding.py deleted file mode 100644 index f83f2bca..00000000 --- a/vllm_ascend/torchair/ops/torchair_vocab_parallel_embedding.py +++ /dev/null @@ -1,38 +0,0 @@ -# -# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. -# This file is a part of the vllm-ascend project. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -from vllm.distributed import tensor_model_parallel_all_reduce - - -def vocab_embedding_forward(self, input_): - if self.tp_size > 1: - # Build the mask. - masked_input, input_mask = self._get_masked_input_and_mask( - input_, self.shard_indices.org_vocab_start_index, - self.shard_indices.org_vocab_end_index, - self.shard_indices.num_org_vocab_padding, - self.shard_indices.added_vocab_start_index, - self.shard_indices.added_vocab_end_index) - else: - masked_input = input_ - # Get the embeddings. 
- output_parallel = self.quant_method.embedding(self, masked_input.long()) - # Mask the output embedding. - if self.tp_size > 1: - output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0) - # Reduce across all the model parallel GPUs. - output = tensor_model_parallel_all_reduce(output_parallel) - return output diff --git a/vllm_ascend/torchair/quantization/__init__.py b/vllm_ascend/torchair/quantization/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/vllm_ascend/torchair/quantization/torchair_w4a8_dynamic.py b/vllm_ascend/torchair/quantization/torchair_w4a8_dynamic.py deleted file mode 100644 index c61ddf32..00000000 --- a/vllm_ascend/torchair/quantization/torchair_w4a8_dynamic.py +++ /dev/null @@ -1,501 +0,0 @@ -# -# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. -# This file is a part of the vllm-ascend project. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -from typing import Any, Callable, Dict, Optional - -import numpy as np -import torch -import torch_npu -from vllm.config import get_current_vllm_config -from vllm.distributed import get_ep_group -from vllm.forward_context import get_forward_context - -from vllm_ascend.ascend_config import get_ascend_config -from vllm_ascend.ascend_forward_context import FusedMoEState -from vllm_ascend.distributed.parallel_state import get_mc2_group -from vllm_ascend.torchair.ops.torchair_fused_moe import torchair_select_experts -from vllm_ascend.torchair.quantization.torchair_w8a8_dynamic import ( - torchair_fused_experts_with_all2all, torchair_fused_experts_with_mc2) -from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor - - -class TorchairAscendW4A8DynamicLinearMethod: - """Linear method for Ascend W4A8_DYNAMIC - """ - - def __init__(self): - self.transpose_weight = True - - vllm_config = get_current_vllm_config() - self.group_size = vllm_config.quant_config.quant_description.get( - "group_size", 256) - quant_version = vllm_config.quant_config.quant_description.get( - "version", "0") - self.new_quant_version = quant_version == "1.0.0" - - from vllm.distributed import get_tensor_model_parallel_world_size - self.tp_size = get_tensor_model_parallel_world_size() - - def get_weight(self, input_size: int, output_size: int, - params_dtype: torch.dtype) -> Dict[str, Any]: - params_dict = {} - - if self.new_quant_version: - pack_factor = 2 - actual_output_size = output_size // pack_factor - params_dict["weight"] = torch.empty(actual_output_size, - input_size, - dtype=torch.int8) - params_dict["_packed_dim"] = 0 - params_dict["_packed_factor"] = pack_factor - else: - params_dict["weight"] = torch.empty(output_size, - input_size, - dtype=torch.int8) - - return params_dict - - @staticmethod - def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]: - return {} - - @staticmethod - def get_perchannel_param(output_size: int, - params_dtype: torch.dtype) -> Dict[str, Any]: - return {} - - def get_pergroup_param(self, - input_size: int, - output_size: int, - params_dtype: 
torch.dtype, - layer_type: Optional[str] = None) -> Dict[str, Any]: - params_dict = {} - params_dict["weight_scale"] = torch.empty(output_size, - 1, - dtype=params_dtype) - params_dict["weight_offset"] = torch.empty(output_size, - 1, - dtype=params_dtype) - params_dict["weight_scale_second"] = torch.empty(output_size, - input_size // - self.group_size, - dtype=params_dtype) - params_dict["weight_offset_second"] = torch.empty(output_size, - input_size // - self.group_size, - dtype=params_dtype) - - if self.new_quant_version: - scale_bias_dim = 16 if layer_type == "row" else 1 - params_dict["scale_bias"] = torch.empty(output_size, - scale_bias_dim, - dtype=torch.float32) - return params_dict - - @staticmethod - def process_scale_second(weight: torch.Tensor, - scale: torch.Tensor, - per_group_scale: torch.Tensor, - is_new_quant: bool = False): - k, n = weight.shape - group_num, n_scale = per_group_scale.shape - - if is_new_quant: - n = n * 2 - - bias = None - if not is_new_quant: - weight_high = weight.to(torch.float32).reshape( - group_num, -1, n) * per_group_scale.reshape(group_num, 1, n) - weight_high = weight_high.reshape(k, n) - bias = 8 * (weight_high.to(torch.float32) * scale).sum(dim=0) - - antiquant_scale = (scale * per_group_scale).reshape(group_num, n) - return antiquant_scale.npu(), bias - - def apply( - self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None, - tp_rank: Optional[int] = None, - ) -> torch.Tensor: - return torch_npu.npu_weight_quant_batchmatmul( - x, - layer.weight, - antiquant_scale=layer.weight_scale_second.to(x.dtype), - antiquant_group_size=self.group_size, - ) - - def process_weights_after_loading(self, layer: torch.nn.Module): - if self.transpose_weight: - layer.weight.data = layer.weight.data.transpose(0, 1).contiguous() - layer.weight_scale.data = layer.weight_scale.data.flatten().to( - torch.float32) - layer.weight_offset.data = layer.weight_offset.data.flatten() - layer.weight_scale_second.data, scale_bias = self.process_scale_second( - layer.weight.data, - layer.weight_scale.data, - layer.weight_scale_second.data.transpose(0, 1).contiguous(), - is_new_quant=self.new_quant_version, - ) - - if self.new_quant_version: - if hasattr(layer, "scale_bias"): - if layer.scale_bias.data.shape[1] == 1: - layer.scale_bias.data = layer.scale_bias.data.flatten() - else: - layer.scale_bias.data = layer.scale_bias.data.contiguous() - else: - if scale_bias is not None: - param = torch.nn.Parameter(scale_bias, requires_grad=False) - layer.register_parameter("weight_scale_bias", param) - - if self.new_quant_version: - assert layer.weight.data.shape[-1] % 4 == 0, \ - f"the last dim of weight needs to be divided by 4, got shape {layer.weight.data.shape}" - layer.weight.data = layer.weight.data.view( - torch.int32).contiguous() - else: - layer.weight.data = torch_npu.npu_convert_weight_to_int4pack( - layer.weight.data.to(torch.int32)) - - -class TorchairAscendW4A8DynamicFusedMoEMethod: - """FusedMoe method for Ascend W4A8_DYNAMIC. 
- """ - - def __init__(self): - self.transpose_weight = True - - self.ep_group = get_ep_group() - - ascend_config = get_ascend_config() - self.dynamic_eplb = ascend_config.dynamic_eplb or ascend_config.expert_map_record_path - self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled - - vllm_config = get_current_vllm_config() - self.group_size = vllm_config.quant_config.quant_description.get( - "group_size", 256) - # NOTE: the weights are quantized from bf16 to int4 through a per-channel quantization process - self.is_per_channel_weight = self.group_size == 0 - quant_version = vllm_config.quant_config.quant_description.get( - "version", "0") - # NOTE: new quantize weights: 2 int4 pack into int8 - self.new_quant_version = quant_version == "1.0.0" - self.tp_size = 1 if vllm_config.parallel_config.enable_expert_parallel else self.ep_group.world_size - if self.new_quant_version and self.tp_size > 16: - raise ValueError( - "The current weight does not support moe part tp>16.") - - try: - device_group = get_mc2_group().device_group - # TODO: Try local_rank = ep_group.rank_in_group - local_rank = torch.distributed.get_rank(group=device_group) - backend = device_group._get_backend(torch.device("npu")) - self.moe_all_to_all_group_name = backend.get_hccl_comm_name( - local_rank) - except AttributeError: - self.moe_all_to_all_group_name = "" - - def get_weight(self, num_experts: int, - intermediate_size_per_partition: int, hidden_sizes: int, - params_dtype: torch.dtype) -> Dict[str, Any]: - param_dict = {} - if self.new_quant_version: - w13_output_size = intermediate_size_per_partition - w2_output_size = hidden_sizes // 2 - else: - w13_output_size = 2 * intermediate_size_per_partition - w2_output_size = hidden_sizes - - param_dict["w13_weight"] = torch.empty(num_experts, - w13_output_size, - hidden_sizes, - dtype=torch.int8) - param_dict["w2_weight"] = torch.empty(num_experts, - w2_output_size, - intermediate_size_per_partition, - dtype=torch.int8) - return param_dict - - def get_dynamic_quant_param(self, num_experts: int, - intermediate_size_per_partition: int, - hidden_sizes: int, - params_dtype: torch.dtype) -> Dict[str, Any]: - param_dict = {} - param_dict["w13_weight_scale"] = torch.empty( - num_experts, - 2 * intermediate_size_per_partition, - 1, - dtype=torch.float32) - - param_dict["w13_weight_offset"] = torch.empty( - num_experts, - 2 * intermediate_size_per_partition, - 1, - dtype=torch.float32) - - param_dict["w2_weight_scale"] = torch.empty(num_experts, - hidden_sizes, - 1, - dtype=torch.float32) - param_dict["w2_weight_offset"] = torch.empty(num_experts, - hidden_sizes, - 1, - dtype=torch.float32) - - if not self.is_per_channel_weight: - param_dict["w13_weight_scale_second"] = torch.empty( - num_experts, - 2 * intermediate_size_per_partition, - hidden_sizes // self.group_size, - dtype=torch.float32) - param_dict["w13_weight_offset_second"] = torch.empty( - num_experts, - 2 * intermediate_size_per_partition, - hidden_sizes // self.group_size, - dtype=torch.float32) - - param_dict["w2_weight_scale_second"] = torch.empty( - num_experts, - hidden_sizes, - intermediate_size_per_partition // self.group_size, - dtype=torch.float32) - param_dict["w2_weight_offset_second"] = torch.empty( - num_experts, - hidden_sizes, - intermediate_size_per_partition // self.group_size, - dtype=torch.float32) - - if self.new_quant_version: - param_dict["w13_scale_bias"] = torch.empty( - num_experts, - 2 * intermediate_size_per_partition, - 1, - dtype=torch.float32) - param_dict["w2_scale_bias"] 
= torch.empty(num_experts, - hidden_sizes, - 16 // self.tp_size, - dtype=torch.float32) - - return param_dict - - def apply( - self, - layer: torch.nn.Module, - x: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - renormalize: bool, - use_grouped_topk: bool = False, - global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, - custom_routing_function: Optional[Callable] = None, - scoring_func: str = "softmax", - e_score_correction_bias: Optional[torch.Tensor] = None, - is_prefill: bool = True, - enable_force_load_balance: bool = True, - log2phy: torch.Tensor = None, - global_redundant_expert_num: int = 0, - shared_experts: Optional[Any] = None, - quantized_x_for_share: Optional[Any] = None, - dynamic_scale_for_share: Optional[Any] = None, - **kwargs, - ) -> torch.Tensor: - assert router_logits.shape[ - 1] == global_num_experts - global_redundant_expert_num, "Number of global experts mismatch (excluding redundancy)" - - if global_num_experts == 256: - topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k( - router_logits, - k=top_k, # topk currently is 8 - bias=e_score_correction_bias, - k_group=topk_group, # fix: 4 - group_count=num_expert_group, # fix 8 - group_select_mode= - 1, # 0: the maximum in the group; 1: topk2.sum(fix) - renorm=0, # 0: softmax->topk(fix); 1: topk->softmax - norm_type=1, # 0: softmax; 1: sigmoid(fix) - # out_flag=False, # todo new api; should the third output be output - # y2_flag=False, # old api; should the third output be output - routed_scaling_factor=1, - eps=float(1e-20)) - else: - topk_weights, topk_ids = torchair_select_experts( - hidden_states=x, - router_logits=router_logits, - top_k=top_k, - use_grouped_topk=use_grouped_topk, - renormalize=renormalize, - topk_group=topk_group, - num_expert_group=num_expert_group, - custom_routing_function=custom_routing_function, - scoring_func=scoring_func, - e_score_correction_bias=e_score_correction_bias, - ) - - fused_moe_state = get_forward_context().fused_moe_state - shared_gate_up, shared_dequant_scale = None, None - if shared_experts is not None and fused_moe_state == FusedMoEState.MC2: - with npu_stream_switch("moe_secondary", 0): - npu_wait_tensor(quantized_x_for_share, router_logits) - share_up_out, _ = shared_experts.gate_up_proj( - (quantized_x_for_share, dynamic_scale_for_share)) - shared_gate_up, shared_dequant_scale = share_up_out[ - 0], share_up_out[1] - - # this is a naive implementation for experts load balance so as - # to avoid accumulating too much tokens on a single rank. - # currently it is only activated when doing profile runs. 
- if enable_force_load_balance: - topk_ids = torch.randint_like(topk_ids, 0, global_num_experts) - - topk_weights = topk_weights.to(x.dtype) - if fused_moe_state == FusedMoEState.MC2: - return torchair_fused_experts_with_mc2( - hidden_states=x, - w1=layer.w13_weight, - w2=layer.w2_weight, - w1_scale=layer.w13_weight_scale, - w2_scale=layer.w2_weight_scale, - w1_scale_bias=layer.w13_scale_bias, - w2_scale_bias=layer.w2_scale_bias, - topk_weights=topk_weights, - topk_ids=topk_ids, - top_k=top_k, - expert_map=expert_map, - moe_all_to_all_group_name=self.moe_all_to_all_group_name, - log2phy=log2phy, - global_redundant_expert_num=global_redundant_expert_num, - shared_experts=shared_experts, - is_torchair=self.torchair_graph_enabled, - quantized_x_for_share=shared_gate_up, - dynamic_scale_for_share=shared_dequant_scale, - mc2_mask=kwargs.get("mc2_mask", None), - dynamic_eplb=self.dynamic_eplb) - else: - # The current implementation of deepseek moe splits hidden_states - # according to tp_size before they are feed into layers module. - # Therefore, all2all is needed no matter how dp/tp is set so as to - # dispatch/combine tokens. - return torchair_fused_experts_with_all2all( - hidden_states=x, - w1=layer.w13_weight, - w2=layer.w2_weight, - w1_scale=layer.w13_weight_scale, - w2_scale=layer.w2_weight_scale, - w1_scale_bias=layer.w13_scale_bias, - w2_scale_bias=layer.w2_scale_bias, - topk_weights=topk_weights, - topk_ids=topk_ids, - top_k=top_k, - expert_map=expert_map, - ep_group=self.ep_group, - log2phy=log2phy, - global_redundant_expert_num=global_redundant_expert_num, - ) - - def process_scale(self, weight: torch.Tensor, scale, per_group_scale): - scale = scale.transpose(1, 2).contiguous() - if self.is_per_channel_weight: - scale_np = scale.cpu().numpy() - scale_np.dtype = np.uint32 - scale_uint64_tensor = torch.from_numpy(scale_np.astype( - np.int64)).npu() - return scale_uint64_tensor, None - per_group_scale = per_group_scale.transpose(1, 2).contiguous() - group_num, k, n = weight.shape - # the weight of the new version is reduced by half by pack n, so it needs to be restored - if self.new_quant_version: - n = n * 2 - per_group_scale = per_group_scale.reshape(group_num, -1, n) - group_num, quantgroup_num, n = per_group_scale.shape - bias = None - if not self.new_quant_version: - weight_high = weight.to(torch.float32).reshape([group_num, quantgroup_num, -1, n]) * \ - per_group_scale.reshape([group_num, quantgroup_num, 1, n]) - weight_high = weight_high.reshape([group_num, k, n]) - bias = 8 * (weight_high.to(torch.float32) * scale).sum(axis=1) - scale_fp32 = (scale * per_group_scale).to(torch.float16).to( - torch.float32) - scale_fp32_np = scale_fp32.cpu().numpy() - scale_fp32_np.dtype = np.uint32 - sscale_uint64 = np.zeros((group_num, quantgroup_num, n * 2), - dtype=np.uint32) - - sscale_uint64[..., ::2] = scale_fp32_np - - sscale_uint64_buffer = np.frombuffer(sscale_uint64.tobytes(), - dtype=np.int64).copy() - sscale_uint64_tensor = torch.from_numpy(sscale_uint64_buffer).reshape( - group_num, quantgroup_num, n) - sscale_uint64_tensor = sscale_uint64_tensor.npu() - return sscale_uint64_tensor, bias - - def update_bias(self, layer, w13_bias, w2_bias): - if self.new_quant_version: - layer.w13_scale_bias.data = layer.w13_scale_bias.data.transpose( - 1, 2).contiguous().sum(axis=1) - layer.w2_scale_bias.data = layer.w2_scale_bias.data.transpose( - 1, 2).contiguous().sum(axis=1) - else: - w13_scale_bias = torch.nn.Parameter(w13_bias, requires_grad=False) - layer.register_parameter("w13_scale_bias", 
w13_scale_bias) - w2_scale_bias = torch.nn.Parameter(w2_bias, requires_grad=False) - layer.register_parameter("w2_scale_bias", w2_scale_bias) - - def pack_to_int32(self, weight: torch.Tensor): - if self.new_quant_version: - # pack 4 int8(int4*2) to int32, because in pytorch, we need to use int32 to represent int4 - assert weight.shape[ - -1] % 4 == 0, "the last dim of weight needs to be divided by 4" - return weight.view(torch.int32).contiguous() - else: - return torch_npu.npu_quantize(weight.to(torch.float32), - torch.tensor([1.]).npu(), None, - torch.quint4x2, -1, False) - - def process_weights_after_loading(self, layer): - if self.transpose_weight: - layer.w13_weight.data = layer.w13_weight.data.transpose( - 1, 2).contiguous() - layer.w2_weight.data = layer.w2_weight.data.transpose( - 1, 2).contiguous() - w13_weight_scale_second = layer.w13_weight_scale_second.data if hasattr( - layer, "w13_weight_scale_second") else None - w2_weight_scale_second = layer.w2_weight_scale_second.data if hasattr( - layer, "w2_weight_scale_second") else None - layer.w13_weight_scale.data, w13_bias = self.process_scale( - layer.w13_weight, layer.w13_weight_scale.data, - w13_weight_scale_second) - layer.w2_weight_scale.data, w2_bias = self.process_scale( - layer.w2_weight, layer.w2_weight_scale.data, - w2_weight_scale_second) - if hasattr(layer, "w13_weight_scale_second"): - # scale_second is no longer used, release this part of the memory - del layer.w13_weight_scale_second - del layer.w2_weight_scale_second - del layer.w13_weight_offset_second - del layer.w2_weight_offset_second - - self.update_bias(layer, w13_bias, w2_bias) - - layer.w13_weight.data = self.pack_to_int32(layer.w13_weight.data) - layer.w2_weight.data = self.pack_to_int32(layer.w2_weight.data) diff --git a/vllm_ascend/torchair/quantization/torchair_w8a8_dynamic.py b/vllm_ascend/torchair/quantization/torchair_w8a8_dynamic.py deleted file mode 100644 index 8909bb79..00000000 --- a/vllm_ascend/torchair/quantization/torchair_w8a8_dynamic.py +++ /dev/null @@ -1,1082 +0,0 @@ -# -# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. -# This file is a part of the vllm-ascend project. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-#
-from typing import Any, Callable, Dict, Optional, Tuple, Union
-
-import torch
-import torch.distributed as dist
-import torch_npu
-from vllm.distributed import GroupCoordinator, get_ep_group
-from vllm.forward_context import get_forward_context
-
-from vllm_ascend.ascend_config import get_ascend_config
-from vllm_ascend.ascend_forward_context import FusedMoEState
-from vllm_ascend.distributed.parallel_state import get_mc2_group
-from vllm_ascend.torchair.ops.torchair_fused_moe import torchair_select_experts
-from vllm_ascend.torchair.utils import (npu_stream_switch, npu_wait_tensor,
-                                         super_kernel)
-from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, AscendDeviceType,
-                               dispose_tensor, get_ascend_device_type,
-                               is_enable_nz,
-                               is_hierarchical_communication_enabled)
-
-
-def torchair_apply_mlp_decode(hidden_states: torch.Tensor,
-                              w1: torch.Tensor,
-                              w1_scale: torch.Tensor,
-                              w2: torch.Tensor,
-                              w2_scale: torch.Tensor,
-                              group_list: torch.Tensor,
-                              dynamic_scale: torch.Tensor = None,
-                              group_list_type: int = 1) -> torch.Tensor:
-    """
-    apply MLP: gate_up_proj -> swiglu -> down_proj
-    Args:
-        hidden_states: input hidden states with shape (num_tokens, hidden_size).
-        w1: expert weights1 with shape
-            (num_experts, hidden_size, intermediate_size * 2)
-        w1_scale: weights1 scale with shape (num_experts, intermediate_size * 2)
-        w2: expert weights2 with shape
-            (num_experts, intermediate_size, hidden_size)
-        w2_scale: weights2 scale with shape (num_experts, hidden_size)
-        group_list: number of tokens for each expert, follow cumsum mode, and
-            with shape (num_experts).
-        transpose_weight:
-            w1: (num_experts, intermediate_size * 2, hidden_size) ->
-                (num_experts, hidden_size, intermediate_size * 2)
-            w2: (num_experts, hidden_size, intermediate_size) ->
-                (num_experts, intermediate_size, hidden_size)
-    Returns:
-        hidden_states: output hidden states after MLP.
-    """
-
-    if dynamic_scale is None:
-        unquantized_hidden_states = hidden_states
-        hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(
-            hidden_states)
-        # Dispose the original unquantized hidden states
-        # to save npu memory because they're no longer used.
- dispose_tensor(unquantized_hidden_states) - else: - pertoken_scale = dynamic_scale - - # gmm1: gate_up_proj - hidden_states = torch_npu.npu_grouped_matmul( - x=[hidden_states], - weight=[w1], - split_item=3, - group_list_type=group_list_type, - group_type=0, - group_list=group_list, - output_dtype=torch.int32)[0] - - # act_fn: swiglu - hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant( - x=hidden_states, - weight_scale=w1_scale, - activation_scale=pertoken_scale, - bias=None, - quant_scale=None, - quant_offset=None, - group_index=group_list, - activate_left=True, - quant_mode=1, - ) - - # gmm2: down_proj - hidden_states = torch_npu.npu_grouped_matmul( - x=[hidden_states], - weight=[w2], - scale=[w2_scale], - per_token_scale=[swiglu_out_scale], - split_item=2, - group_list_type=group_list_type, - group_type=0, - group_list=group_list, - output_dtype=w2_scale.dtype)[0] - return hidden_states - - -def torchair_apply_mlp(hidden_states: torch.Tensor, - w1: torch.Tensor, - w1_scale: torch.Tensor, - w2: torch.Tensor, - w2_scale: torch.Tensor, - group_list: torch.Tensor, - dynamic_scale: torch.Tensor = None, - group_list_type: int = 1, - w1_scale_bias: torch.Tensor = None, - w2_scale_bias: torch.Tensor = None) -> torch.Tensor: - """ - apply MLP: gate_up_proj -> swiglu -> down_proj - - Args: - hidden_states: input hidden states with shape (num_tokens, hidden_size). - w1: expert weights1 with shape - (num_experts, hidden_size, intermediate_size * 2) - w1_scale: weights1 scale with shape (num_experts, intermediate_size * 2) - w2: expert weights2 with shape - (num_experts, intermediate_size, hidden_size) - w2_scale: weights2 scale with shape (num_experts, hidden_size) - group_list: number of tokens for each expert, follow cumsum mode, and - with shape (num_experts). - transpose_weight: - w1: (num_experts, intermediate_size * 2, hidden_size) -> - (num_experts, hidden_size, intermediate_size * 2) - w2: (num_experts, hidden_size, intermediate_size) -> - (num_experts, intermediate_size, hidden_size) - - Returns: - hidden_states: output hidden states after MLP. - """ - - if dynamic_scale is None: - unquantized_hidden_states = hidden_states - hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant( - hidden_states) - # Dispose the original unquantized hidden states - # to save npu memory because they're no longer used. 
- dispose_tensor(unquantized_hidden_states) - else: - pertoken_scale = dynamic_scale - - bias1, bias2 = None, None - _output_dtype = w2_scale.dtype - - if w1_scale_bias is not None: - if group_list_type == 0: - group_list = torch.cat( - [group_list[:1], torch.diff(group_list, dim=0)]) - group_list_type = 1 - bias1 = [w1_scale_bias] - bias2 = [w2_scale_bias] - # TODO w4a8 scene: dynamic acquisition of dtype in the future - _output_dtype = torch.bfloat16 - - # gmm1: gate_up_proj - hidden_states = torch_npu.npu_grouped_matmul( - x=[hidden_states], - weight=[w1], - scale=[w1_scale], - bias=bias1, - per_token_scale=[pertoken_scale], - split_item=2, - group_list_type=group_list_type, - group_type=0, - group_list=group_list, - output_dtype=_output_dtype)[0] - - # act_fn: swiglu - hidden_states = torch_npu.npu_swiglu(hidden_states) - hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant( - hidden_states) - - # gmm2: down_proj - hidden_states = torch_npu.npu_grouped_matmul( - x=[hidden_states], - weight=[w2], - scale=[w2_scale], - bias=bias2, - per_token_scale=[swiglu_out_scale], - split_item=2, - group_list_type=group_list_type, - group_type=0, - group_list=group_list, - output_dtype=_output_dtype)[0] - - return hidden_states - - -def torchair_fused_experts_with_mc2( - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - w1_scale: torch.Tensor, - w2_scale: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - top_k: int, - expert_map: torch.Tensor = None, - moe_all_to_all_group_name: str = "", - log2phy: torch.Tensor = None, - global_redundant_expert_num: int = 0, - shared_experts: Optional[Any] = None, - is_torchair: bool = False, - quantized_x_for_share: Optional[Any] = None, - dynamic_scale_for_share: Optional[Any] = None, - mc2_mask: Optional[torch.Tensor] = None, - shared_gate_up: Optional[Any] = None, - shared_dequant_scale: Optional[Any] = None, - w1_scale_bias: torch.Tensor = None, - w2_scale_bias: torch.Tensor = None, - dynamic_eplb: bool = False, -) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: - assert mc2_mask is not None - if log2phy is not None: - topk_ids = log2phy[topk_ids] - - quant_mode = 2 - ep_group = get_mc2_group() - ep_rank_id = ep_group.rank_in_group - ep_world_size = ep_group.world_size - - # NOTE: Currently, when in A3 or in torchair graph, we need to pass in some extra param into dispatch & combine - need_extra_args = (get_ascend_device_type() == AscendDeviceType._910_93 - or is_torchair) - - # NOTE: Currently, when in A3, we need to pass in some extra param into dispatch & combine - a3_need_extra_args = get_ascend_device_type() == AscendDeviceType._910_93 - # NOTE: When in A2, setting the environment variables HCCL_INTRA_PCIE_ENABLE=1 and - # HCCL_INTRA_ROCE_ENABLE=0 can reduce cross-machine communication traffic and significantly - # improve communication performance. 
- need_expert_scale = is_hierarchical_communication_enabled() - - enable_dispatch_v2 = hasattr(torch_npu, "npu_moe_distribute_dispatch_v2") - - if (expert_map is not None): - moe_expert_num = len(expert_map) - else: - moe_expert_num = global_redundant_expert_num - # hidden_states = hidden_states.bfloat16() - kwargs_mc2 = { - "x": hidden_states, - "expert_ids": topk_ids, - "expert_shard_type": 0, - "shared_expert_rank_num": 0, - "moe_expert_num": moe_expert_num, - "global_bs": 0, - } - - stage1_kwargs = { - "scales": None, - "quant_mode": quant_mode, - "group_ep": moe_all_to_all_group_name, - "ep_world_size": ep_world_size, - "ep_rank_id": ep_rank_id, - } - if need_extra_args: - stage1_kwargs.update({ - "group_tp": moe_all_to_all_group_name, - "tp_world_size": 1, - "tp_rank_id": 0, - }) - if a3_need_extra_args and enable_dispatch_v2: - stage1_kwargs.update({ - "x_active_mask": mc2_mask, - }) - if need_expert_scale: - stage1_kwargs.update({ - "expert_scales": topk_weights.to(torch.float32), - }) - kwargs_mc2.update(stage1_kwargs) - - output = torch_npu.npu_moe_distribute_dispatch_v2( - **kwargs_mc2 - ) if enable_dispatch_v2 else torch_npu.npu_moe_distribute_dispatch( - **kwargs_mc2) - # comm_stream.wait_stream(torch.npu.current_stream()) - expand_x, dynamic_scale, assist_info_for_combine, expert_token_nums, \ - ep_recv_counts, _, expand_scales = output[0:7] - - if shared_experts is not None: - with npu_stream_switch("moe_secondary", 0): - npu_wait_tensor(shared_gate_up, expand_x) - shared_act_out = shared_experts.act_fn( - (shared_gate_up, shared_dequant_scale)) - shared_act, swiglu_out_scale = shared_act_out[0], shared_act_out[1] - - # `expand_x` will be disposed in the `apply_mlp` function - if w1_scale_bias is None: - down_out_list = torchair_apply_mlp_decode(expand_x, - w1, - w1_scale, - w2, - w2_scale, - expert_token_nums, - dynamic_scale=dynamic_scale) - else: - # w4a8 scene, cannot use apply_mlp_decode because the operator is not supported - down_out_list = torchair_apply_mlp(expand_x, - w1, - w1_scale, - w2, - w2_scale, - expert_token_nums, - dynamic_scale=dynamic_scale, - w1_scale_bias=w1_scale_bias, - w2_scale_bias=w2_scale_bias) - - # moeCombine - kwargs_mc2 = { - "expand_x": down_out_list, - "expert_ids": topk_ids, - "expert_scales": topk_weights.to(torch.float32), - "expert_shard_type": 0, - "shared_expert_rank_num": 0, - "moe_expert_num": moe_expert_num, - "global_bs": 0, - } - tp_recv_counts = torch.empty(1, - dtype=torch.int32, - device=hidden_states.device) - stage3_kwargs = { - "ep_send_counts": ep_recv_counts, - "group_ep": moe_all_to_all_group_name, - "ep_world_size": ep_world_size, - "ep_rank_id": ep_rank_id, - "expand_scales": expand_scales, - } - if enable_dispatch_v2: - stage3_kwargs.update({ - "assist_info_for_combine": - assist_info_for_combine, - }) - else: - stage3_kwargs.update({ - "expand_idx": assist_info_for_combine, - }) - if need_extra_args: - stage3_kwargs.update({ - "tp_send_counts": tp_recv_counts, - "group_tp": moe_all_to_all_group_name, - "tp_world_size": 1, - "tp_rank_id": 0, - }) - if a3_need_extra_args and enable_dispatch_v2: - stage3_kwargs.update({ - "x_active_mask": mc2_mask, - }) - kwargs_mc2.update(stage3_kwargs) - - hidden_states = torch_npu.npu_moe_distribute_combine_v2( - **kwargs_mc2 - ) if enable_dispatch_v2 else torch_npu.npu_moe_distribute_combine( - **kwargs_mc2) - - if shared_experts is None: - if dynamic_eplb: - return (hidden_states, 1, expert_token_nums) - return hidden_states - else: - with npu_stream_switch("moe_secondary", 0): - 
npu_wait_tensor(shared_act, down_out_list) - shared_output, _ = shared_experts.down_proj( - (shared_act, swiglu_out_scale)) - if dynamic_eplb: - return (hidden_states, shared_output, 1, expert_token_nums) - return (hidden_states, shared_output) - - -def torchair_init_routing_quant(hidden_states, top_k, topk_ids, - global_num_experts): - num_tokens, _ = hidden_states.shape - row_idx_len = num_tokens * top_k - row_idx = (torch.arange(0, - row_idx_len, - dtype=torch.int32, - device=hidden_states.device).view( - top_k, -1).permute(1, 0).contiguous()) - hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing( - hidden_states, - row_idx=row_idx, - expert_idx=topk_ids, - active_num=num_tokens) - - expanded_row_idx = (expanded_row_idx.view(top_k, -1).permute( - 1, 0).contiguous().view(-1)) - global_expert_tokens = torch.bincount(expanded_expert_idx, - minlength=global_num_experts) - global_expert_tokens = global_expert_tokens.to(torch.int32) - quantized_tokens, token_scales = torch_npu.npu_dynamic_quant(hidden_states) - return quantized_tokens, expanded_row_idx, global_expert_tokens, token_scales - - -# currently expert parallelism implemented with all2all -# is under-optimized. -def torchair_fused_experts_with_all2all( - hidden_states: torch.Tensor, - w1: torch.Tensor, - w1_scale: torch.Tensor, - w2: torch.Tensor, - w2_scale: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - top_k: int, - expert_map: torch.Tensor = None, - ep_group: GroupCoordinator = None, - log2phy: torch.Tensor = None, - global_redundant_expert_num: int = 0, - w1_scale_bias: torch.Tensor = None, - w2_scale_bias: torch.Tensor = None, -): - if log2phy is not None: - topk_ids = log2phy[topk_ids] - original_shape = hidden_states.shape - if len(original_shape) == 3: - hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) - - num_tokens, _ = hidden_states.shape - num_experts = w1.shape[0] - - if expert_map is not None: - assert ep_group is not None, "ep_group must be provided when expert_map is given" - global_num_experts = len(expert_map) - if hasattr(torch_npu, "npu_moe_init_routing_quant"): - quantized_tokens, expanded_row_idx, global_expert_tokens, _, token_scales = torch_npu.npu_moe_init_routing_quant( - hidden_states, - expert_idx=topk_ids.to(torch.int32), - active_num=0, - expert_capacity=0, - expert_num=global_num_experts, - drop_pad_mode=0, - expert_tokens_num_mode=2, - expert_tokens_before_capacity_flag=False, - quant_mode=1, - ) - else: - quantized_tokens, expanded_row_idx, global_expert_tokens, token_scales = torchair_init_routing_quant( - hidden_states, top_k, topk_ids, global_num_experts) - - gather_sizes = global_expert_tokens.new_empty( - global_expert_tokens.shape[0]) - dist.all_to_all_single(gather_sizes, - global_expert_tokens, - group=ep_group.device_group) - token_counts_combined = torch.stack( - [gather_sizes, global_expert_tokens], dim=0) - token_counts_combined = token_counts_combined.view( - 2, ep_group.world_size, -1).sum(dim=2) - token_counts_combined_cpu = token_counts_combined.to( - torch.device("cpu"), non_blocking=False).numpy() - all_tokens = gather_sizes.sum() - - gathered_tokens = quantized_tokens.new_empty(all_tokens.item(), - quantized_tokens.shape[1]) - dynamic_scale = token_scales.new_empty(gathered_tokens.shape[0]) - gather_size_list = token_counts_combined_cpu[1] - scatter_size_list = token_counts_combined_cpu[0] - - dist.all_to_all_single(gathered_tokens, - quantized_tokens, - scatter_size_list, - gather_size_list, - 
group=ep_group.device_group) - dist.all_to_all_single(dynamic_scale, - token_scales, - scatter_size_list, - gather_size_list, - group=ep_group.device_group) - - hidden_states, dynamic_scale, inverse_indices, expert_tokens = torch_npu.npu_moe_re_routing( - gathered_tokens, - gather_sizes.view(ep_group.world_size, -1), - per_token_scales=dynamic_scale) - expert_tokens = expert_tokens.to(torch.int64) - group_list_type = 1 - else: - row_idx_len = num_tokens * top_k - row_idx = torch.arange(0, - row_idx_len, - dtype=torch.int32, - device=topk_weights.device).view( - top_k, -1).permute(1, 0).contiguous() - hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing( - hidden_states, - row_idx=row_idx, - expert_idx=topk_ids, - active_num=num_tokens) - - expert_tokens = torch_npu.npu_moe_compute_expert_tokens( - expanded_expert_idx, num_experts) - expert_tokens = expert_tokens.to(torch.int64) - group_list_type = 0 - dynamic_scale = None - - # `hidden_states` will be disposed in the `apply_mlp` function - hidden_states = torchair_apply_mlp( - hidden_states, - w1, - w1_scale, #17 - w2, - w2_scale, - expert_tokens, #16 - dynamic_scale=dynamic_scale, - group_list_type=group_list_type, - w1_scale_bias=w1_scale_bias, - w2_scale_bias=w2_scale_bias) - - if expert_map is not None: - reordered_outputs = torch.index_select( - hidden_states, - dim=0, - # Workaround: Convert to float so that argsort runs on AI Core instead of slower AICPU - index=inverse_indices.to(torch.float32).argsort().to(torch.int32)) - - hidden_states = reordered_outputs.new_empty(*quantized_tokens.shape) - dist.all_to_all_single(hidden_states, - reordered_outputs, - gather_size_list, - scatter_size_list, - group=ep_group.device_group) - final_hidden_states = torch_npu.npu_moe_finalize_routing( - hidden_states, - skip1=None, - skip2=None, - bias=None, - scales=topk_weights, - expanded_src_to_dst_row=expanded_row_idx, - export_for_source_row=None, - drop_pad_mode=2) - else: - # TODO: Reorder device memory 2 times here, replace the current - # implementation here when suitable operators become available. 
- final_hidden_states = torch_npu.npu_moe_finalize_routing( - hidden_states, - skip1=None, - skip2=None, - bias=None, - scales=topk_weights, - expanded_src_to_dst_row=expanded_row_idx, - export_for_source_row=topk_ids, - ) - if len(original_shape) == 3: - final_hidden_states = final_hidden_states.view(original_shape) - return final_hidden_states - - -def torchair_fused_experts_with_allgather(hidden_states: torch.Tensor, - w1: torch.Tensor, - w1_scale: torch.Tensor, - w2: torch.Tensor, - w2_scale: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - top_k: int, - expert_map: torch.Tensor = None): - original_shape = hidden_states.shape - if len(original_shape) == 3: - hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) - num_tokens = hidden_states.shape[0] - batch_size, hidden_size = hidden_states.shape - topk_weights = topk_weights.to(hidden_states.dtype) - - ep_group = get_ep_group().device_group - ep_rank = torch.distributed.get_rank(group=ep_group) - ep_size = torch.distributed.get_world_size(ep_group) - - global_num_experts = len(expert_map) - local_num_experts = global_num_experts // ep_size - - hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(hidden_states) - - hidden_states, expanded_x_idx, expert_tokens, pertoken_scale = torch_npu.npu_moe_init_routing_v2( - hidden_states, - topk_ids, - scale=pertoken_scale, - offset=None, - active_num=num_tokens * top_k, - expert_num=global_num_experts, - expert_tokens_num_type=1, - expert_tokens_num_flag=True, - active_expert_range=[ - ep_rank * local_num_experts, (ep_rank + 1) * local_num_experts - ], - quant_mode=-1, - row_idx_type=1) - group_list_type = 1 - - sorted_topk_weight = torch.index_select(topk_weights.view(-1), 0, - expanded_x_idx) - row_index = expanded_x_idx // topk_ids.shape[-1] - row_index = row_index.to(torch.int64) - share_input = torch.zeros((batch_size, hidden_size), - dtype=torch.bfloat16, - device="npu") - - hidden_states = torch_npu.npu_grouped_matmul( - x=[hidden_states], - weight=[w1], - split_item=3, - group_list_type=group_list_type, - group_type=0, - group_list=expert_tokens, - output_dtype=torch.int32)[0] - - # act_fn: swiglu - hidden_states, pertoken_scale = torch_npu.npu_dequant_swiglu_quant( - x=hidden_states, - weight_scale=w1_scale.to(torch.float32), - activation_scale=pertoken_scale, - bias=None, - quant_scale=None, - quant_offset=None, - group_index=expert_tokens, - activate_left=True, - quant_mode=1, - ) - - final_hidden_states = torch_npu.npu_grouped_matmul_finalize_routing( - hidden_states, - w2, - scale=w2_scale.to(torch.float32), - bias=None, - pertoken_scale=pertoken_scale.view(-1), - group_list=expert_tokens, - shared_input=share_input, - logit=sorted_topk_weight.to(torch.float32), - row_index=row_index, - output_bs=batch_size).to(torch.bfloat16) - - if len(original_shape) == 3: - final_hidden_states = final_hidden_states.view(original_shape) - - return final_hidden_states - - -def torchair_fused_experts(hidden_states: torch.Tensor, - w1: torch.Tensor, - w1_scale: torch.Tensor, - w2: torch.Tensor, - w2_scale: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - top_k: int, - expert_map: torch.Tensor = None): - original_shape = hidden_states.shape - if len(original_shape) == 3: - hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) - - num_tokens, _ = hidden_states.shape - num_experts = w1.shape[0] - dtype = hidden_states.dtype - device = hidden_states.device - - if expert_map is not None: - # Generate token indices and flatten - 
token_indices = (torch.arange(num_tokens, - device=device, - dtype=torch.int64).unsqueeze(1).expand( - -1, top_k).reshape(-1)) - - # Flatten token-to-expert mappings and map to local experts - weights_flat = topk_weights.view(-1) - experts_flat = topk_ids.view(-1) - local_experts_flat = expert_map[experts_flat] - - # Filter valid token-expert pairs - mask = local_experts_flat != -1 - filtered_weights = torch.where( - mask, weights_flat, torch.zeros_like(weights_flat)).to(dtype) - filtered_experts = torch.where( - mask, local_experts_flat, - torch.full_like(local_experts_flat, - num_experts)).to(topk_ids.dtype) - - # Sort by local expert IDs - sort_indices = torch.argsort(filtered_experts) - sorted_token_indices = token_indices[sort_indices] - sorted_weights = filtered_weights[sort_indices] - - # Compute token counts with minlength of num_experts - # This is equivalent to but faster than: - # >>> token_counts = torch.bincount(filtered_experts, minlength=num_experts)[:-1] - token_counts = torch.zeros(num_experts + 1, - device=device, - dtype=torch.int64) - ones = torch.ones_like(filtered_experts, dtype=torch.int64) - token_counts.scatter_add_(0, filtered_experts.to(torch.int64), ones) - expert_tokens = token_counts[:num_experts] - # Rearrange hidden_states - hidden_states = hidden_states[sorted_token_indices] - group_list_type = 1 - else: - row_idx_len = num_tokens * top_k - row_idx = torch.arange(0, - row_idx_len, - dtype=torch.int32, - device=topk_weights.device).view( - top_k, -1).permute(1, 0).contiguous() - hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing( - hidden_states, - row_idx=row_idx, - expert_idx=topk_ids, - active_num=num_tokens) - - expert_tokens = torch_npu.npu_moe_compute_expert_tokens( - expanded_expert_idx, num_experts) - expert_tokens = expert_tokens.to(torch.int64) - group_list_type = 0 - - # `hidden_states` will be disposed in the `apply_mlp` function - hidden_states = torchair_apply_mlp(hidden_states, - w1, - w1_scale, - w2, - w2_scale, - expert_tokens, - group_list_type=group_list_type) - - if expert_map is not None: - hidden_states.mul_(sorted_weights.unsqueeze(1)) - final_hidden_states = torch.zeros(*original_shape, - device=device, - dtype=dtype) - - num_valid_tokens = mask.sum() - valid_token_mask = torch.arange( - 0, sorted_token_indices.shape[0], - device=device).unsqueeze(1) < num_valid_tokens - hidden_states = hidden_states.masked_fill_(~valid_token_mask, - 0).to(dtype) - final_hidden_states.index_add_(0, sorted_token_indices, hidden_states) - else: - # TODO: Reorder device memory 2 times here, replace the current - # implementation here when suitable operators become available. - final_hidden_states = torch_npu.npu_moe_finalize_routing( - hidden_states, - skip1=None, - skip2=None, - bias=None, - scales=topk_weights, - expanded_src_to_dst_row=expanded_row_idx, - export_for_source_row=topk_ids, - ) - - if len(original_shape) == 3: - final_hidden_states = final_hidden_states.view(original_shape) - return final_hidden_states - - -class TorchairAscendW8A8DynamicLinearMethod: - """Linear method for Ascend W8A8_DYNAMIC. 
- """ - - def __init__(self): - self.transpose_weight = True - - @staticmethod - def get_weight(input_size: int, output_size: int, - params_dtype: torch.dtype) -> Dict[str, Any]: - params_dict = { - "weight": torch.empty(output_size, input_size, dtype=torch.int8) - } - return params_dict - - @staticmethod - def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]: - return {} - - @staticmethod - def get_perchannel_param( - output_size: int, - params_dtype: torch.dtype, - ) -> Dict[str, Any]: - params_dict = {} - params_dict["weight_scale"] = torch.empty(output_size, - 1, - dtype=params_dtype) - params_dict["weight_offset"] = torch.empty(output_size, - 1, - dtype=params_dtype) - return params_dict - - def get_pergroup_param(self, - input_size: int, - output_size: int, - params_dtype: torch.dtype, - layer_type: Optional[str] = None) -> Dict[str, Any]: - return {} - - @staticmethod - def apply( - layer: torch.nn.Module, - x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], - bias: Optional[torch.Tensor] = None, - tp_rank: Optional[int] = 0, - ) -> torch.Tensor: - config = getattr(layer, "_ascend_quant_config", {}) - if not isinstance(x, tuple): - output_dtype = config.get("output_dtype", x.dtype) - quantized_x, dynamic_scale = torch_npu.npu_dynamic_quant(x) - else: - assert "output_dtype" in config.keys(), ( - f"DynamicLinearMethod needs explicitly specified `output_dtype`" - f"for pre-quantized input, got config [{config}]") - output_dtype = config["output_dtype"] - quantized_x, dynamic_scale = x - pertoken_scale = (dynamic_scale - if config.get("pertoken_scale", True) else None) - - output = torch_npu.npu_quant_matmul( - quantized_x, - layer.weight, - layer.weight_scale, - pertoken_scale=pertoken_scale, - bias=bias, - output_dtype=output_dtype, - ) - return ((output, dynamic_scale) - if config.get("return_scale", False) else output) - - def process_weights_after_loading(self, layer): - if self.transpose_weight: - layer.weight.data = layer.weight.data.transpose(0, 1).contiguous() - # cast quantized weight tensors in NZ format (29) for higher inference speed - if is_enable_nz(): - layer.weight.data = torch_npu.npu_format_cast( - layer.weight.data, 29) - layer.weight_scale.data = layer.weight_scale.data.flatten() - layer.weight_scale_fp32 = layer.weight_scale.data.to(torch.float32) - layer.weight_offset.data = layer.weight_offset.data.flatten() - - -class TorchairAscendW8A8DynamicFusedMoEMethod: - """FusedMoe method for Ascend W8A8_DYNAMIC. 
- """ - - def __init__(self): - self.transpose_weight = True - - self.ep_group = get_ep_group() - - ascend_config = get_ascend_config() - self.dynamic_eplb = ascend_config.dynamic_eplb or ascend_config.expert_map_record_path - self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled - self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp - - try: - device_group = get_mc2_group().device_group - # TODO: Try local_rank = ep_group.rank_in_group - local_rank = torch.distributed.get_rank(group=device_group) - backend = device_group._get_backend(torch.device("npu")) - self.moe_all_to_all_group_name = backend.get_hccl_comm_name( - local_rank) - except AttributeError: - self.moe_all_to_all_group_name = "" - - @staticmethod - def get_weight(num_experts: int, intermediate_size_per_partition: int, - hidden_sizes: int, - params_dtype: torch.dtype) -> Dict[str, Any]: - param_dict = {} - param_dict["w13_weight"] = torch.empty(num_experts, - 2 * - intermediate_size_per_partition, - hidden_sizes, - dtype=torch.int8) - param_dict["w2_weight"] = torch.empty(num_experts, - hidden_sizes, - intermediate_size_per_partition, - dtype=torch.int8) - return param_dict - - @staticmethod - def get_dynamic_quant_param(num_experts: int, - intermediate_size_per_partition: int, - hidden_sizes: int, - params_dtype: torch.dtype) -> Dict[str, Any]: - param_dict = {} - param_dict["w13_weight_scale"] = torch.empty( - num_experts, - 2 * intermediate_size_per_partition, - 1, - dtype=params_dtype) - param_dict["w13_weight_offset"] = torch.empty( - num_experts, - 2 * intermediate_size_per_partition, - 1, - dtype=params_dtype) - param_dict["w2_weight_scale"] = torch.empty(num_experts, - hidden_sizes, - 1, - dtype=params_dtype) - param_dict["w2_weight_offset"] = torch.empty(num_experts, - hidden_sizes, - 1, - dtype=params_dtype) - return param_dict - - def apply( - self, - layer: torch.nn.Module, - x: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - renormalize: bool, - use_grouped_topk: bool = False, - global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, - custom_routing_function: Optional[Callable] = None, - scoring_func: str = "softmax", - e_score_correction_bias: Optional[torch.Tensor] = None, - is_prefill: bool = True, - enable_force_load_balance: bool = True, - log2phy: torch.Tensor = None, - global_redundant_expert_num: int = 0, - shared_experts: Optional[Any] = None, - quantized_x_for_share: Optional[Any] = None, - dynamic_scale_for_share: Optional[Any] = None, - prefix: str = "", - running_in_super_kernel: bool = False, - **kwargs, - ) -> torch.Tensor: - assert router_logits.shape[ - 1] == global_num_experts - global_redundant_expert_num, "Number of global experts mismatch (excluding redundancy)" - - is_deepseek_v3_r1 = global_num_experts - global_redundant_expert_num == 256 - - fused_moe_state = get_forward_context().fused_moe_state - if self.enable_shared_expert_dp and fused_moe_state == FusedMoEState.MC2: - fused_moe_state = FusedMoEState.All2All - shared_gate_up, shared_dequant_scale = None, None - - with super_kernel(prefix, - "stream-fusion=1", - enabled=running_in_super_kernel): - # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern - if is_deepseek_v3_r1: - topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k( - router_logits, - k=top_k, # topk currently is 8 - bias=e_score_correction_bias, - k_group=topk_group, # fix: 4 - 
group_count=num_expert_group, # fix 8 - group_select_mode= - 1, # 0: the maximum in the group; 1: topk2.sum(fix) - renorm=0, # 0: softmax->topk(fix); 1: topk->softmax - norm_type=1, # 0: softmax; 1: sigmoid(fix) - # out_flag=False, # todo new api; should the third output be output - # y2_flag=False, # old api; should the third output be output - routed_scaling_factor=1, - eps=float(1e-20)) - else: - topk_weights, topk_ids = torchair_select_experts( - hidden_states=x, - router_logits=router_logits, - top_k=top_k, - use_grouped_topk=use_grouped_topk, - renormalize=renormalize, - topk_group=topk_group, - num_expert_group=num_expert_group, - custom_routing_function=custom_routing_function, - scoring_func=scoring_func, - e_score_correction_bias=e_score_correction_bias, - ) - - if shared_experts is not None and fused_moe_state == FusedMoEState.MC2: - with npu_stream_switch("moe_secondary", 0): - npu_wait_tensor(quantized_x_for_share, router_logits) - share_up_out, _ = shared_experts.gate_up_proj( - (quantized_x_for_share, dynamic_scale_for_share)) - shared_gate_up, shared_dequant_scale = share_up_out[ - 0], share_up_out[1] - - # this is a naive implementation for experts load balance so as - # to avoid accumulating too much tokens on a single rank. - # currently it is only activated when doing profile runs. - if enable_force_load_balance: - topk_ids = torch.randint_like( - topk_ids, 0, - global_num_experts - global_redundant_expert_num) - topk_weights = topk_weights.to(x.dtype) - - if fused_moe_state == FusedMoEState.AllGatherEP: - return torchair_fused_experts_with_allgather( - hidden_states=x, - w1=layer.w13_weight, - w1_scale=layer.w13_weight_scale, - w2=layer.w2_weight, - w2_scale=layer.w2_weight_scale, - topk_weights=topk_weights, - topk_ids=topk_ids, - top_k=top_k, - expert_map=expert_map) - elif fused_moe_state == FusedMoEState.MC2: - with super_kernel(prefix, - "stream-fusion=1", - enabled=running_in_super_kernel): - return torchair_fused_experts_with_mc2( - hidden_states=x, - w1=layer.w13_weight, - w2=layer.w2_weight, - w1_scale=layer.w13_weight_scale_fp32, - w2_scale=layer.w2_weight_scale, - topk_weights=topk_weights, - topk_ids=topk_ids, - top_k=top_k, - expert_map=expert_map, - moe_all_to_all_group_name=self.moe_all_to_all_group_name, - log2phy=log2phy, - global_redundant_expert_num=global_redundant_expert_num, - shared_experts=shared_experts, - is_torchair=self.torchair_graph_enabled, - mc2_mask=kwargs.get("mc2_mask", None), - shared_gate_up=shared_gate_up, - shared_dequant_scale=shared_dequant_scale, - dynamic_eplb=self.dynamic_eplb) - elif fused_moe_state in [ - FusedMoEState.AllGather, FusedMoEState.NaiveMulticast - ]: - return torchair_fused_experts(hidden_states=x, - w1=layer.w13_weight, - w1_scale=layer.w13_weight_scale, - w2=layer.w2_weight, - w2_scale=layer.w2_weight_scale, - topk_weights=topk_weights, - topk_ids=topk_ids, - top_k=top_k, - expert_map=expert_map) - else: - # The current implementation of deepseek moe splits hidden_states - # according to tp_size before they are feed into layers module. - # Therefore, all2all is needed no matter how dp/tp is set so as to - # dispatch/combine tokens. 
- return torchair_fused_experts_with_all2all( - hidden_states=x, - w1=layer.w13_weight, - w1_scale=layer.w13_weight_scale, - w2=layer.w2_weight, - w2_scale=layer.w2_weight_scale, - topk_weights=topk_weights, - topk_ids=topk_ids, - top_k=top_k, - expert_map=expert_map, - ep_group=self.ep_group, - log2phy=log2phy, - global_redundant_expert_num=global_redundant_expert_num, - ) - - def process_weights_after_loading(self, layer): - if self.transpose_weight: - layer.w13_weight.data = layer.w13_weight.data.transpose( - 1, 2).contiguous() - layer.w2_weight.data = layer.w2_weight.data.transpose( - 1, 2).contiguous() - if is_enable_nz(): - torch_npu.npu_format_cast_(layer.w13_weight, ACL_FORMAT_FRACTAL_NZ) - torch_npu.npu_format_cast_(layer.w2_weight, ACL_FORMAT_FRACTAL_NZ) - layer.w13_weight_scale.data = layer.w13_weight_scale.data.view( - layer.w13_weight_scale.data.shape[0], -1) - layer.w13_weight_scale_fp32 = layer.w13_weight_scale.data.to( - torch.float32) - layer.w13_weight_offset.data = layer.w13_weight_offset.data.view( - layer.w13_weight_offset.data.shape[0], -1) - layer.w2_weight_scale.data = layer.w2_weight_scale.data.view( - layer.w2_weight_scale.data.shape[0], -1) - layer.w2_weight_offset.data = layer.w2_weight_offset.data.view( - layer.w2_weight_offset.data.shape[0], -1) diff --git a/vllm_ascend/torchair/torchair_attention.py b/vllm_ascend/torchair/torchair_attention.py deleted file mode 100644 index 4afa65e1..00000000 --- a/vllm_ascend/torchair/torchair_attention.py +++ /dev/null @@ -1,457 +0,0 @@ -# -# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# This file is a part of the vllm-ascend project. 
-# - -from dataclasses import dataclass -from typing import List, Optional, Tuple, Type - -import numpy as np -import torch -import torch.nn as nn -import torch_npu -from vllm.attention.backends.abstract import (AttentionImpl, AttentionLayer, - AttentionType) -from vllm.attention.backends.utils import PAD_SLOT_ID -from vllm.config import VllmConfig -from vllm.utils.math_utils import cdiv - -from vllm_ascend.attention.attention_v1 import (AscendAttentionBackend, - AscendAttentionMetadataBuilder, - AscendAttentionState, - AscendMetadata) -from vllm_ascend.attention.utils import AscendCommonAttentionMetadata -from vllm_ascend.torchair.utils import TorchairCommonAttentionMetadata -from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, AscendDeviceType, - aligned_16, get_ascend_device_type, nd_to_nz_2d) - - -class AscendAttentionTorchairBackend(AscendAttentionBackend): - accept_output_buffer: bool = True - - @staticmethod - def get_name() -> str: - return "ASCEND_TORCHAIR" - - @staticmethod - def get_impl_cls() -> Type["AscendAttentionTorchairBackendImpl"]: - return AscendAttentionTorchairBackendImpl - - @staticmethod - def get_builder_cls() -> type["AscendAttentionTorchairMetadataBuilder"]: - return AscendAttentionTorchairMetadataBuilder - - @staticmethod - def get_kv_cache_shape( - num_blocks: int, - block_size: int, - num_kv_heads: int, - head_size: int, - ) -> Tuple[int, ...]: - return (2, num_blocks, block_size, num_kv_heads * head_size) - - @staticmethod - def get_bsh_kv_cache_shape( - num_blocks: int, - block_size: int, - num_kv_heads: int, - head_size: int, - ) -> Tuple[int, ...]: - return (2, num_blocks, block_size, num_kv_heads * head_size) - - -@dataclass -class AscendDecodeMetadata: - # Input positions for rotrary embeddings since for MLA the rotary - # position embeddings are applied inside the attention backend - input_positions: torch.Tensor - block_table: torch.Tensor - seq_lens: torch.Tensor - max_seq_lens: int - seq_lens_list: list[int] - attn_mask: Optional[torch.Tensor] = None - - -@dataclass -class AscendTorchairMetadata(AscendMetadata): - - decode: Optional[AscendDecodeMetadata] = None - - -class AscendAttentionTorchairMetadataBuilder(AscendAttentionMetadataBuilder): - - def __init__( - self, - kv_cache_spec, - layer_names, - vllm_config: VllmConfig, - device: torch.device, - ): - super().__init__(kv_cache_spec, layer_names, vllm_config, device) - self.max_num_blocks_per_req = cdiv( - self.model_config.max_model_len, - self.vllm_config.cache_config.block_size) - self.max_blocks = (self.model_config.max_model_len + - self.vllm_config.cache_config.block_size - - 1) // self.vllm_config.cache_config.block_size - - def _get_graph_runner_block_tables( - self, num_seqs: int, block_tables: torch.Tensor) -> torch.Tensor: - max_blocks = self.max_blocks - - graph_block_tables = torch.zeros((num_seqs, max_blocks), - dtype=block_tables.dtype, - device=block_tables.device) - - num_blocks = block_tables.size(1) - if num_blocks <= max_blocks: - graph_block_tables[:num_seqs, : - num_blocks] = block_tables[:num_seqs, : - num_blocks] - else: - graph_block_tables[:num_seqs, : - max_blocks] = block_tables[:num_seqs, : - max_blocks] - - return graph_block_tables[:, :max_blocks] - - def build_torchair_graph_dummy( - self, common_attn_metadata: TorchairCommonAttentionMetadata - ) -> AscendTorchairMetadata: - device = self.device - num_reqs = common_attn_metadata.num_reqs - block_table = torch.zeros((num_reqs, self.max_blocks), - dtype=torch.int32, - device=device) - block_table = 
self._get_graph_runner_block_tables( - num_reqs, block_table) - seq_lens = torch.ones(num_reqs, dtype=torch.int32, device=device) - input_positions = torch.zeros(num_reqs, - dtype=torch.int32, - device=device).long() - slot_mapping = torch.full((num_reqs, ), - PAD_SLOT_ID, - dtype=torch.int32, - device=device) - query_start_loc = torch.full((num_reqs, ), - -1, - dtype=torch.int32, - device=device) - - decode_metadata = AscendDecodeMetadata(input_positions=input_positions, - block_table=block_table, - seq_lens=seq_lens, - seq_lens_list=seq_lens.tolist(), - max_seq_lens=1) - - attn_metadata = AscendTorchairMetadata( - num_actual_tokens=common_attn_metadata.num_actual_tokens, - block_tables=block_table, - query_lens=0, - query_start_loc=query_start_loc, - seq_lens=seq_lens, - slot_mapping=slot_mapping, - attn_state=AscendAttentionState.DecodeOnly, - decode=decode_metadata) - return attn_metadata - - def build( - self, - common_prefix_len: int, - common_attn_metadata: AscendCommonAttentionMetadata, - model: Optional[nn.Module] = None, - ): - num_reqs = common_attn_metadata.num_reqs - num_actual_tokens = common_attn_metadata.num_actual_tokens - - block_table = common_attn_metadata.block_table_tensor - block_table[:num_reqs, :self.max_num_blocks_per_req] = ( - block_table[:num_reqs]) - - seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs] - slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens] - attn_mask = common_attn_metadata.attn_mask - - attn_state = common_attn_metadata.attn_state - if get_ascend_device_type( - ) == AscendDeviceType._310P and attn_state == AscendAttentionState.PrefillNoCache: - mask_nz = nd_to_nz_2d(attn_mask) - attn_mask = torch_npu.npu_format_cast(mask_nz.contiguous(), 29) - - query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[: - num_reqs - + 1] - query_start_loc = query_start_loc_cpu.to(self.device, - non_blocking=True) - query_lens = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] - input_positions = common_attn_metadata.positions[: - num_actual_tokens].long( - ) - - decode_metadata = None - graph_pad_size = common_attn_metadata.graph_pad_size - use_torchair_graph = graph_pad_size > -1 - if common_attn_metadata.attn_state in [ - AscendAttentionState.DecodeOnly, - ]: - max_seq_lens = seq_lens.max().item() - num_seqs = len(seq_lens) - if use_torchair_graph and common_attn_metadata.attn_state in [ - AscendAttentionState.DecodeOnly, - ]: - num_reqs_pad_size = 0 - num_token_pad_size = 0 - if graph_pad_size != 0: - pad_value = 0 - num_token_pad_size = graph_pad_size - num_actual_tokens - num_reqs_pad_size = ( - graph_pad_size // - common_attn_metadata.decode_token_per_req - num_reqs) - pad_value = 1 - padded_seq_lens = seq_lens.tolist() + [pad_value - ] * num_reqs_pad_size - - seq_lens = torch.from_numpy( - np.array(padded_seq_lens).astype(np.int32)) - padding = torch.full((num_token_pad_size, ), - PAD_SLOT_ID, - dtype=slot_mapping.dtype, - device=slot_mapping.device) - slot_mapping = torch.cat([slot_mapping, padding]) - block_table_padding = torch.zeros( - (num_reqs_pad_size, ) + block_table.shape[1:], - dtype=block_table.dtype, - device=block_table.device) - block_table = torch.cat([block_table, block_table_padding], - dim=0) - block_table = self._get_graph_runner_block_tables( - num_seqs + num_reqs_pad_size, block_table) - padding_0 = torch.zeros(num_token_pad_size, - dtype=input_positions.dtype, - device=input_positions.device) - input_positions = torch.cat([input_positions, padding_0]) - - decode_metadata = AscendDecodeMetadata( - 
input_positions=input_positions, - block_table=block_table, - seq_lens=seq_lens, - seq_lens_list=seq_lens.tolist(), - max_seq_lens=max_seq_lens, - attn_mask=None) - - attn_metadata = AscendTorchairMetadata( - decode=decode_metadata, - num_actual_tokens=num_actual_tokens, - block_tables=block_table, - query_start_loc=query_start_loc, - query_lens=query_lens, - seq_lens=seq_lens, - max_query_len=common_attn_metadata.max_query_len, - slot_mapping=slot_mapping, - attn_mask=attn_mask, - attn_state=attn_state) - return attn_metadata - - -class AscendAttentionTorchairBackendImpl(AttentionImpl): - - def __init__( - self, - num_heads: int, - head_size: int, - scale: float, - num_kv_heads: int, - alibi_slopes: Optional[List[float]], - sliding_window: Optional[int], - kv_cache_dtype: str, - logits_soft_cap: Optional[float], - attn_type: str, - kv_sharing_target_layer_name: Optional[str], - **kwargs, - ) -> None: - self.num_heads = num_heads - self.head_size = head_size - self.scale = float(scale) - self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads - self.hidden_size = self.num_heads * self.head_size - self.kv_cache_dtype = kv_cache_dtype - self.sliding_window = sliding_window - if alibi_slopes is not None: - alibi_slopes = torch.tensor(alibi_slopes, - dtype=torch.float32, - device="npu") - self.alibi_slopes = alibi_slopes - self.attn_type = attn_type - - assert self.num_heads % self.num_kv_heads == 0 - self.num_queries_per_kv = self.num_heads // self.num_kv_heads - self.key_cache = None - self.value_cache = None - self.scale_tensor = torch.zeros((), device='npu', dtype=torch.int32) - - def forward( - self, - layer: AttentionLayer, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AscendTorchairMetadata, - output: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - """Forward pass with Ascend attention. - Args: - query: shape = [batch_size, seq_len, num_heads * head_size] - key: shape = [batch_size, seq_len, num_kv_heads * head_size] - value: shape = [batch_size, seq_len, num_kv_heads * head_size] - kv_cache: shape = [2, num_blocks, block_size, - num_kv_heads, head_size] - key_cache = [num_blocks, block_size, - num_kv_heads, head_size] - value_cache = [num_blocks, block_size, - num_kv_heads, head_size] - attn_metadata: Metadata for attention. 
- Returns: - shape = [batch_size * seq_len, num_heads, head_size] - """ - num_tokens = query.shape[0] - use_kv_cache_quant = (kv_cache is not None and len(kv_cache) > 0 - and kv_cache[0].numel() > 0 - and kv_cache[0].dtype == torch.int8) - if output is None: - output = torch.empty(num_tokens, - self.num_heads, - self.head_size, - dtype=query.dtype, - device=query.device) - - if hasattr(layer, 'quant_method') and use_kv_cache_quant: - output = layer.quant_method.apply(layer, query, key, value, - kv_cache, attn_metadata, - self.attn_type, self.scale, - output) - return output.view(num_tokens, self.hidden_size) - - if attn_metadata is None: - return output.view(num_tokens, self.hidden_size).fill_(0) - - output = output.view(-1, self.num_heads, self.head_size) - - assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0 - attn_type = self.attn_type - if attn_type != AttentionType.DECODER: - raise NotImplementedError("Encoder self-attention and " - "encoder/decoder cross-attention " - "are not implemented for " - "AscendAttentionTorchairBackendImpl") - - if kv_cache is not None and kv_cache[0].numel() > 0: - key_cache, value_cache = kv_cache[0], kv_cache[1] - slots = attn_metadata.slot_mapping - - block_size = self.scale_tensor + key_cache.shape[1] - slots_indices = slots.reshape(-1, 1) - block_indices = slots_indices // block_size - slots_indices = slots_indices % block_size - indices = torch.cat((block_indices, slots_indices), dim=1) - torch_npu.npu_scatter_nd_update_(key_cache, indices, key) - torch_npu.npu_scatter_nd_update_(value_cache, indices, value) - if attn_metadata.attn_state == AscendAttentionState.PrefillCacheHit: - self.key_cache = key_cache - self.value_cache = value_cache - - if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache: - assert attn_metadata is not None - assert attn_metadata.attn_mask is not None - mask = attn_metadata.attn_mask - - # View q k v to BSH. 
- query = query.view(-1, self.num_heads, self.head_size) - key = key.view(-1, self.num_kv_heads, self.head_size) - value = value.view(-1, self.num_kv_heads, self.head_size) - - if get_ascend_device_type() == AscendDeviceType._310P: - # align q k v output tensors - query = aligned_16(query) - key = aligned_16(key) - value = aligned_16(value) - output = aligned_16(output) - - # do reformat in case of broadcasted tensors - mask = mask.repeat(attn_metadata.seq_lens.size(0), 1, 1, 1) - mask = torch_npu.npu_format_cast(mask.contiguous(), - ACL_FORMAT_FRACTAL_NZ) - - torch_npu._npu_flash_attention(query=query, - key=key, - value=value, - mask=mask, - seq_len=attn_metadata.seq_lens, - scale_value=self.scale, - num_heads=self.num_heads, - num_kv_heads=self.num_kv_heads, - out=output) - output = output[:num_tokens, :, :] - elif attn_metadata.attn_state == AscendAttentionState.PrefillCacheHit: - assert attn_metadata is not None - assert attn_metadata.attn_mask is not None - compress_mask = attn_metadata.attn_mask - batch_size = attn_metadata.query_lens.shape[0] - block_table = attn_metadata.block_tables[:batch_size, :] - torch_npu._npu_flash_attention_qlens( - query=query, - key_cache=self.key_cache, - value_cache=self.value_cache, - block_table=block_table, - mask=compress_mask, - seq_len=attn_metadata.query_lens, - context_lens=attn_metadata.seq_lens, - num_kv_heads=self.num_kv_heads, - num_heads=self.num_heads, - scale_value=self.scale, - out=output) - elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly: - decode_meta = attn_metadata.decode - assert decode_meta is not None - seq_lens = decode_meta.seq_lens_list - block_table = decode_meta.block_table - block_size = key_cache.shape[1] - query = query.view(num_tokens, 1, - self.num_heads * self.head_size).contiguous() - output, _ = torch_npu.npu_fused_infer_attention_score( - query=query, - key=key_cache, - value=value_cache, - query_rope=None, - key_rope=None, - num_heads=self.num_heads, - num_key_value_heads=self.num_kv_heads, - input_layout='BSH', - atten_mask=decode_meta.attn_mask, - sparse_mode=0, - scale=self.scale, - antiquant_mode=0, - antiquant_scale=None, - block_table=block_table, - block_size=block_size, - actual_seq_lengths_kv=seq_lens, - ) - else: - raise NotImplementedError( - "Torchair graph mode with non-MLA attention backend is still experimental." - "v1 scheduler(chunked prefill) is not supported at this moment." 
- ) - - return output.view(num_tokens, self.hidden_size) diff --git a/vllm_ascend/torchair/torchair_mla.py b/vllm_ascend/torchair/torchair_mla.py deleted file mode 100644 index 5846bbd2..00000000 --- a/vllm_ascend/torchair/torchair_mla.py +++ /dev/null @@ -1,1263 +0,0 @@ -from dataclasses import dataclass -from typing import TYPE_CHECKING, Optional, Tuple, Type, TypeVar - -import numpy as np -import torch -import torch.nn as nn -import torch_npu -from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer, - MLAAttentionImpl) -from vllm.attention.backends.utils import PAD_SLOT_ID -from vllm.config import VllmConfig, get_current_vllm_config -from vllm.distributed import get_tensor_model_parallel_world_size -from vllm.model_executor.layers.linear import (LinearBase, - UnquantizedLinearMethod) -from vllm.utils.math_utils import cdiv, round_down - -import vllm_ascend.envs as envs_ascend -from vllm_ascend.ascend_config import get_ascend_config -from vllm_ascend.attention.attention_v1 import AscendAttentionState -from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata, - split_decodes_and_prefills) -from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch -from vllm_ascend.torchair.utils import (TorchairCommonAttentionMetadata, - npu_stream_switch, npu_wait_tensor) -from vllm_ascend.worker.npu_input_batch import InputBatch - -if TYPE_CHECKING: - from vllm.v1.core.sched.output import SchedulerOutput - - -class AscendMLATorchairBackend(AttentionBackend): - - accept_output_buffer: bool = True - - @staticmethod - def get_name() -> str: - return "ASCEND_MLA_TORCHAIR" - - @staticmethod - def get_builder_cls(): - return AscendMLATorchairMetadataBuilder - - @staticmethod - def get_kv_cache_shape(num_blocks: int, block_size: int, num_kv_heads: int, - head_size: int) -> tuple[int, ...]: - return (num_blocks, block_size, num_kv_heads, head_size) - - @staticmethod - def get_impl_cls() -> Type["MLAAttentionImpl"]: - return AscendMLATorchairImpl - - -@dataclass -class AscendMLATorchairPrefillMetadata: - """ Prefill Specific Metadata for Ascend""" - - @dataclass - class TorchairChunkedContextMetadata: - # New for MLA (compared to FlashAttention) - # For handling chunked prefill - cu_seq_lens: torch.Tensor - starts: torch.Tensor - seq_tot: list[int] - max_seq_lens: list[int] - workspace: torch.Tensor - chunk_seq_lens: torch.Tensor - chunk_seq_lens_npu: torch.Tensor - - attn_mask: torch.Tensor - query_lens: torch.Tensor - seq_lens: list[int] - context_lens: torch.Tensor - input_positions: torch.Tensor - query_start_loc: torch.Tensor - block_table: torch.Tensor - max_query_len: int - max_seq_lens: int - chunked_context: Optional[TorchairChunkedContextMetadata] = None - sin: torch.Tensor = None - cos: torch.Tensor = None - - -@dataclass -class AscendMLATorchairDecodeMetadata: - # Input positions for rotrary embeddings since for MLA the rotary - # position embeddings are applied inside the attention backend - input_positions: torch.Tensor - block_table: torch.Tensor - seq_lens: torch.Tensor - max_seq_lens: int - seq_lens_list: list[int] - actual_seq_lengths_q: Optional[list[int]] = None - attn_mask: Optional[torch.Tensor] = None - sin: torch.Tensor = None - cos: torch.Tensor = None - - -@dataclass -class AscendMLATorchairMetadata: - """Metadata for MLACommon. - - NOTE: Please read the comment at the top of the file before trying to - understand this class - """ - # NOTE(sang): Definition of context_len, query_len, and seq_len. 
- # |---------- N-1 iteration --------| - # |---------------- N iteration ---------------------| - # |- tokenA -|......................|-- newTokens ---| - # |---------- context_len ----------| - # |-------------------- seq_len ---------------------| - # |-- query_len ---| - - num_actual_tokens: int # Number of tokens excluding padding. - slot_mapping: torch.Tensor - query_start_loc: torch.Tensor - seq_lens: torch.Tensor - block_tables: torch.Tensor - - # New for MLA (compared to FlashAttention) - # For handling prefill decode split - num_decodes: int - num_decode_tokens: int - num_prefills: int - - # For logging. - num_input_tokens: int = 0 # Number of tokens including padding. - - query_lens: Optional[list[int]] = None - # The dimension of the attention heads - head_dim: Optional[int] = None - attn_mask: torch.Tensor = None - # chunked prefill by default if no attn_states passed - attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill - - decode: Optional[AscendMLATorchairDecodeMetadata] = None - prefill: Optional[AscendMLATorchairPrefillMetadata] = None - - def __post_init__(self): - pass - # supported_head_sizes = AscendMLABackend.get_supported_head_sizes() - # if self.head_dim is not None and self.head_dim \ - # not in supported_head_sizes: - # raise ValueError( - # f"Only {supported_head_sizes} are supported for head_dim,", - # f"received {self.head_dim}.") - - -M = TypeVar("M", bound=AscendMLATorchairMetadata) - - -class AscendMLATorchairMetadataBuilder: - """ - NOTE: Please read the comment at the top of the file before trying to - understand this class - """ - - # _attn_mask_builder = None - def __init__(self, - kv_cache_spec, - layer_names, - vllm_config: VllmConfig, - device: torch.device, - metadata_cls: Optional[AscendMLATorchairMetadata] = None): - self.metadata_cls: Optional[AscendMLATorchairMetadata] = metadata_cls \ - if metadata_cls is not None else AscendMLATorchairMetadata # type: ignore - self.vllm_config = vllm_config - self.model_config = vllm_config.model_config - self.device = device - scheduler_config = vllm_config.scheduler_config - self.block_size = vllm_config.cache_config.block_size - self.max_blocks = (vllm_config.model_config.max_model_len + - self.block_size - 1) // self.block_size - self.chunked_prefill_enabled = scheduler_config.enable_chunked_prefill - if self.chunked_prefill_enabled: - self.chunked_prefill_workspace_size = min( - # Max sure there is enough for 8 full length request or at least - # 4 pages of cache per request - max(8 * self.model_config.max_model_len, - 4 * scheduler_config.max_num_seqs * self.block_size), - # For long-context models try not to over-allocate limiting - # kv-cache space, limiting it to 64k tokens, - # which would result in the workspace being: - # 2*(576)*(64*1024) = 144mb - # (assuming 576 MLA head dim, and fp16) - # which would result in up-projected context being - # 2*(192*128)*(64*1024) = 3gb - # (assuming 192 QK head dim, 128 heads, and fp16) - 128 * 1024) - assert self.chunked_prefill_workspace_size >= \ - scheduler_config.max_num_seqs * self.block_size - self.chunked_prefill_workspace = torch.empty( - (self.chunked_prefill_workspace_size, - self.model_config.get_head_size()), - dtype=self.model_config.dtype, - device=device, - ) - ascend_config = get_ascend_config() - self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled - self.rope_dim = self.model_config.hf_text_config.qk_rope_head_dim - self.cos_cache = None - self.sin_cache = None - - def reorder_batch(self, input_batch: 
"InputBatch", - scheduler_output: "SchedulerOutput") -> bool: - # We now want to reorder the batch so that the "decode" requests are at - # the front and the "prefill" requests are at the using the least amount - # swaps possible. (NOTE for now we loosely use "decode" to mean requests - # where attention is likely memory-bound and "prefill" to mean requests - # where attention is likely compute-bound, TODO(lucas): figure out a - # better naming here) - decodes = [] - prefills = [] - - for i, req_id in enumerate(input_batch.req_ids): - num_tokens = scheduler_output.num_scheduled_tokens[req_id] - num_spec_tokens = len( - scheduler_output.scheduled_spec_decode_tokens.get(req_id, [])) - # For torch air graph mode we treat spec decoding as decode. - if self.torchair_graph_enabled: - if num_tokens - num_spec_tokens == 1: - decodes.append(i) - else: - prefills.append(i) - # For eager mode we treat spec decoding as chunked prefill. - else: - if num_tokens == 1: - decodes.append(i) - else: - prefills.append(i) - - # We hope that this is fairly minimal since decodes - # should be around for a number of iterations so hopefully they are - # relatively stationary (and new request are generally appended to the - # persistent batch so already should be at the back) - # To achieve this we loop over the decodes in descending order and - # the prefills in ascending order. We swap decodes from the "back" - # i.e. past where the last decode should be in the reodorered with - # prefills from the front of the batch. - # `decodes` and `prefills` are already in ascending order just based on - # the above loop - num_decodes = len(decodes) - num_prefills = len(prefills) - first_prefill = 0 - modified_batch = False - - for i in range(1, min(num_decodes, num_prefills) + 1): - # If the decode is at the "back" of the batch, i, we can swap it - # with the prefill closest to the front of the batch - if decodes[num_decodes - i] >= num_decodes: - input_batch.swap_states(prefills[first_prefill], - decodes[num_decodes - i]) - first_prefill += 1 - modified_batch = True - else: - break - - # Save for next `build` call - # TODO(lucas): this is a bit of a hack, we should probably have a - # better way of doing this - return modified_batch - - def _get_graph_runner_block_tables( - self, num_seqs: int, block_tables: torch.Tensor) -> torch.Tensor: - max_blocks = self.max_blocks - - graph_block_tables = torch.zeros((num_seqs, max_blocks), - dtype=block_tables.dtype, - device=block_tables.device) - - num_blocks = block_tables.size(1) - if num_blocks <= max_blocks: - graph_block_tables[:num_seqs, : - num_blocks] = block_tables[:num_seqs, : - num_blocks] - else: - graph_block_tables[:num_seqs, : - max_blocks] = block_tables[:num_seqs, : - max_blocks] - - return graph_block_tables[:, :max_blocks] - - def build_torchair_graph_dummy( - self, - common_attn_metadata: TorchairCommonAttentionMetadata, - ) -> AscendMLATorchairMetadata: - device = self.device - num_reqs = common_attn_metadata.num_reqs - block_table = torch.zeros((num_reqs, self.max_blocks), - dtype=torch.int32, - device=device) - block_table = self._get_graph_runner_block_tables( - num_reqs, block_table) - num_tokens = num_reqs * common_attn_metadata.decode_token_per_req - seq_lens = torch.zeros(num_reqs, dtype=torch.int32, device=device) - seq_lens_list = [0] * num_reqs - input_positions = torch.zeros(num_tokens, - dtype=torch.int32, - device=device).long() - slot_mapping = torch.full((num_tokens, ), - PAD_SLOT_ID, - dtype=torch.int32, - device=device) - query_start_loc = 
torch.full((num_reqs, ), - -1, - dtype=torch.int32, - device=device) - sin = torch.ones(num_tokens, - 1, - 1, - self.rope_dim, - dtype=self.model_config.dtype, - device=device) - cos = torch.ones(num_tokens, - 1, - 1, - self.rope_dim, - dtype=self.model_config.dtype, - device=device) - if self.vllm_config.speculative_config is not None and\ - self.vllm_config.speculative_config.method == 'mtp': - attn_state = AscendAttentionState.SpecDecoding - num_decode_tokens = 2 - else: - attn_state = AscendAttentionState.DecodeOnly - num_decode_tokens = 1 - decode_metadata = AscendMLATorchairDecodeMetadata( - input_positions=input_positions, - block_table=block_table, - seq_lens=seq_lens, - seq_lens_list=seq_lens_list, - max_seq_lens=1, - attn_mask=common_attn_metadata.spec_attn_mask, - actual_seq_lengths_q=common_attn_metadata. - actual_seq_lengths_q[:num_reqs], - sin=sin, - cos=cos, - ) - return self.metadata_cls( # type: ignore - num_input_tokens=common_attn_metadata.num_actual_tokens, - num_actual_tokens=common_attn_metadata.num_actual_tokens, - slot_mapping=slot_mapping, - head_dim=self.model_config.get_head_size(), - num_decodes=1, - num_decode_tokens=num_decode_tokens, - num_prefills=0, - attn_mask=common_attn_metadata.attn_mask, - attn_state=attn_state, - prefill=None, - decode=decode_metadata, - query_start_loc=query_start_loc, - seq_lens=seq_lens, - block_tables=block_table, - ) - - def build( - self, - common_prefix_len: int, - common_attn_metadata: AscendCommonAttentionMetadata, - model: nn.Module, - ) -> AscendMLATorchairMetadata: - num_reqs = common_attn_metadata.num_reqs - num_actual_tokens = common_attn_metadata.num_actual_tokens - query_start_loc = common_attn_metadata.query_start_loc - query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu - if self.torchair_graph_enabled and common_attn_metadata.attn_state in [ - AscendAttentionState.DecodeOnly, - AscendAttentionState.SpecDecoding - ]: - decode_threshold = common_attn_metadata.decode_token_per_req - else: - # TODO(xyx): remove the if condition after mla supports torch mode speculative decoding - decode_threshold = 1 - num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \ - split_decodes_and_prefills(common_attn_metadata, decode_threshold=decode_threshold) - assert num_decodes + num_prefills == num_reqs - assert num_decode_tokens + num_prefill_tokens == num_actual_tokens - - # Note(simon): be careful about the CPU <> GPU memory movement in this - # function. We should avoid GPU -> CPU sync as much as possible because - # it blocks on all previous kernels. 
- device = self.device - - block_table = (common_attn_metadata.block_table_tensor[:num_reqs]) - slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens] - input_positions = common_attn_metadata.positions[: - num_actual_tokens].long( - ) - - if self.cos_cache is None: - self.cos_cache = model.model.layers[ - 0].self_attn.rotary_emb.cos_cached - self.sin_cache = model.model.layers[ - 0].self_attn.rotary_emb.sin_cached - if self.cos_cache.dtype != self.model_config.dtype: # type: ignore - self.cos_cache = self.cos_cache.to( # type: ignore - self.model_config.dtype) # type: ignore - self.sin_cache = self.sin_cache.to( # type: ignore - self.model_config.dtype) # type: ignore - - query_seq_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] - query_lens = query_seq_lens_cpu[:num_reqs] - seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs] - num_computed_tokens_cpu = (seq_lens - query_lens) - - prefill_metadata = None - chunked_context_metadata = None - if num_prefills > 0: - reqs_start = num_decodes # prefill_start - tokens_start = num_decode_tokens - max_query_len = query_lens[reqs_start:].max().item() - max_seq_lens = seq_lens[reqs_start:].max().item() - prefill_query_start_loc = query_start_loc[ - reqs_start:] - query_start_loc[reqs_start] - - context_lens_cpu = num_computed_tokens_cpu[reqs_start:num_reqs] - max_context_len_cpu = context_lens_cpu.max().item() - num_prefills_with_context_cpu = (context_lens_cpu > 0).sum().item() - if self.chunked_prefill_enabled and max_context_len_cpu > 0: - max_context_chunk = (self.chunked_prefill_workspace_size // - num_prefills_with_context_cpu) - max_context_chunk = round_down(max_context_chunk, - self.block_size) - - assert max_context_chunk > 0 - num_chunks = cdiv(max_context_len_cpu, max_context_chunk) - chunk_starts = torch.arange(num_chunks, dtype=torch.int32) \ - .unsqueeze(1).expand(-1, num_prefills) * max_context_chunk - chunk_ends = torch.min(context_lens_cpu.unsqueeze(0), - chunk_starts + max_context_chunk) - chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0) - cu_seq_lens_cpu = torch.zeros(num_chunks, - num_prefills + 1, - dtype=torch.int32, - pin_memory=True) - torch.cumsum(chunk_seq_lens, - dim=1, - out=cu_seq_lens_cpu[:, 1:], - dtype=torch.int32) - chunked_context_metadata = \ - AscendMLATorchairPrefillMetadata.TorchairChunkedContextMetadata( - cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True), - starts=chunk_starts.to(device, non_blocking=True), - seq_tot=chunk_seq_lens.sum(dim=1).tolist(), - max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(), - chunk_seq_lens=chunk_seq_lens, - chunk_seq_lens_npu=chunk_seq_lens.npu(), - workspace=self.chunked_prefill_workspace, - ) - prefill_input_positions = input_positions[tokens_start:] - cos = self.cos_cache[ - prefill_input_positions].unsqueeze( # type: ignore - 1).unsqueeze(2) - sin = self.sin_cache[ - prefill_input_positions].unsqueeze( # type: ignore - 1).unsqueeze(2) - prefill_metadata = AscendMLATorchairPrefillMetadata( - attn_mask=common_attn_metadata.attn_mask, - query_lens=query_lens[reqs_start:].to(torch.int32), - seq_lens=seq_lens, - context_lens=seq_lens[reqs_start:], - input_positions=prefill_input_positions, - block_table=block_table[reqs_start:, ...], - max_query_len=max_query_len, - max_seq_lens=max_seq_lens, - query_start_loc=prefill_query_start_loc, - chunked_context=chunked_context_metadata, - sin=sin, - cos=cos, - ) - - decode_metadata = None - graph_pad_size = common_attn_metadata.graph_pad_size - use_torchair_graph = graph_pad_size != -1 
- if num_decodes > 0: - # Notice that num_decodes != num_decode_tokens in SpecDecoding Scenario - actual_seq_lengths_q = query_start_loc[1:num_decodes + 1].tolist() - max_seq_lens = seq_lens[:num_decodes].max().item() - seq_lens = seq_lens[:num_decodes] - input_positions = input_positions[:num_decode_tokens] - block_table = block_table[:num_decodes, ...] - num_token_pad_size = 0 - if use_torchair_graph and common_attn_metadata.attn_state in [ - AscendAttentionState.DecodeOnly, - AscendAttentionState.SpecDecoding - ]: - num_reqs_pad_size = 0 - if graph_pad_size != 0: - pad_value = 0 - num_token_pad_size = graph_pad_size - num_decode_tokens - num_reqs_pad_size = ( - graph_pad_size // - common_attn_metadata.decode_token_per_req - num_reqs) - # For the case when some request reach the max-tokens limit in this forward processing, - # so in this forward new_tokens scheduled is less than decode_token_per_req(1 + spec_token_num). - # Details can see PR:https://github.com/vllm-project/vllm/pull/27922 - num_reqs_pad_size = max(0, num_reqs_pad_size) - - padded_seq_lens = seq_lens.tolist( - ) + [pad_value] * num_reqs_pad_size - else: - padded_seq_lens = seq_lens.tolist() - - seq_lens = torch.from_numpy( - np.array(padded_seq_lens).astype(np.int32)) - seq_lens_list = padded_seq_lens - slot_padding = torch.full((num_token_pad_size, ), - PAD_SLOT_ID, - dtype=slot_mapping.dtype, - device=slot_mapping.device) - slot_mapping = torch.cat([slot_mapping, slot_padding]) - block_table_padding = torch.zeros( - (num_reqs_pad_size, ) + block_table.shape[1:], - dtype=block_table.dtype, - device=block_table.device) - block_table = torch.cat([block_table, block_table_padding], - dim=0) - block_table = self._get_graph_runner_block_tables( - num_reqs + num_reqs_pad_size, block_table) - position_padding = torch.zeros(num_token_pad_size, - dtype=input_positions.dtype, - device=input_positions.device) - input_positions = torch.cat( - [input_positions, position_padding]) - actual_seq_lengths_q = self.pad_actual_seq_len_q( - num_reqs_pad_size, num_reqs, actual_seq_lengths_q, - common_attn_metadata) - else: - seq_lens_list = seq_lens.tolist() - # mtp torchair + PD scenario, last element of actual_seq_lengths_q must equal to batch_size(num_tokens) - batch_size = num_decode_tokens + num_token_pad_size - if actual_seq_lengths_q[-1] != batch_size \ - and common_attn_metadata.attn_state == AscendAttentionState.SpecDecoding: - actual_seq_lengths_q[-1] = batch_size - - cos = self.cos_cache[input_positions].unsqueeze( # type: ignore - 1).unsqueeze(2) - sin = self.sin_cache[input_positions].unsqueeze( # type: ignore - 1).unsqueeze(2) - - decode_metadata = AscendMLATorchairDecodeMetadata( - input_positions=input_positions, - block_table=block_table, - seq_lens=seq_lens, - seq_lens_list=seq_lens_list, - max_seq_lens=max_seq_lens, - attn_mask=common_attn_metadata.spec_attn_mask, - actual_seq_lengths_q=actual_seq_lengths_q, - sin=sin, - cos=cos) - - return self.metadata_cls( # type: ignore - num_actual_tokens=num_actual_tokens, - query_lens=query_lens.tolist(), - slot_mapping=slot_mapping, - head_dim=self.model_config.get_head_size(), - num_decodes=num_decodes, - num_decode_tokens=num_decode_tokens, - num_prefills=num_prefills, - attn_mask=common_attn_metadata.attn_mask, - attn_state=common_attn_metadata.attn_state, - prefill=prefill_metadata, - decode=decode_metadata, - query_start_loc=query_start_loc, - block_tables=block_table, - seq_lens=seq_lens, - ) - - def pad_actual_seq_len_q(self, num_reqs_pad_size, num_reqs, - 
actual_seq_lengths_q, common_attn_metadata): - """ - Pads actual_seq_lengths_q evenly to not exceed 16 tokens per request - in order to meet the requirement of npu_fused_infer_attention_score. - - In Torchair scenario, the lengths of the queries must be padded to the same length. - And npu_fused_infer_attention_score constraint requires the last element must equal to batch_size(num_tokens). - - For example: - batch_size=36, num_reqs_pad_size=2, num_reqs=16 - By default, each request should have inference 2 token, which means actual_seq_lengths_q should be - [2,4,6,8,10,12,14,16,18,20,22,24,26,28,30,32,34,36]. - - However, mtp torchair + PD scenario, the actual_seq_lengths_q may be - [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16] before padding, since the first decode request only has 1 token. - In order to meet the requirement of npu_fused_infer_attention_score, we need to pad actual_seq_lengths_q evenly to not exceed 16 tokens per request. - after padding actual_seq_lengths_q should be similar to [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,32,36] - """ - FIA_SEQ_LEN_LIMIT = 16 - need_padding = num_reqs_pad_size != 0 and \ - len(common_attn_metadata.actual_seq_lengths_q) > num_reqs and \ - common_attn_metadata.actual_seq_lengths_q[num_reqs] - actual_seq_lengths_q[-1] > FIA_SEQ_LEN_LIMIT - if need_padding: - padding_seq_len_q = common_attn_metadata.actual_seq_lengths_q[ - num_reqs:num_reqs + num_reqs_pad_size] - start_val = actual_seq_lengths_q[-1] - end_val = padding_seq_len_q[-1] - - num_step = len(padding_seq_len_q) - interpolated = np.round( - np.linspace(start_val, end_val, - num_step + 1)[1:]).astype(int).tolist() - assert interpolated[-1] == end_val - assert len(interpolated) == len(padding_seq_len_q) - actual_seq_lengths_q = actual_seq_lengths_q + interpolated - else: - actual_seq_lengths_q = actual_seq_lengths_q + common_attn_metadata.actual_seq_lengths_q[ - num_reqs:num_reqs + num_reqs_pad_size] - - return actual_seq_lengths_q - - -class AscendMLATorchairImpl(MLAAttentionImpl): - """ - NOTE: Please read the comment at the top of the file before trying to - understand this class - """ - - def __init__( - self, - num_heads: int, - head_size: int, - scale: float, - num_kv_heads: int, - alibi_slopes: Optional[list[float]], - sliding_window: Optional[int], - kv_cache_dtype: str, - logits_soft_cap: Optional[float], - attn_type: str, - kv_sharing_target_layer_name: Optional[str], - **kwargs, - ) -> None: - self.num_heads = num_heads - self.head_size = head_size - self.scale = float(scale) - self.num_kv_heads = num_kv_heads - self.kv_cache_dtype = kv_cache_dtype - - # MLA Args - self.q_lora_rank = kwargs['q_lora_rank'] - self.kv_lora_rank = kwargs['kv_lora_rank'] - self.qk_nope_head_dim = kwargs['qk_nope_head_dim'] - self.qk_rope_head_dim = kwargs['qk_rope_head_dim'] - self.qk_head_dim = kwargs['qk_head_dim'] - self.v_head_dim = kwargs['v_head_dim'] - self.rotary_emb = kwargs['rotary_emb'] - self.q_proj = kwargs['q_proj'] - self.kv_b_proj = kwargs['kv_b_proj'] - self.o_proj = kwargs['o_proj'] - self.kv_a_proj_with_mqa = kwargs.get('kv_a_proj_with_mqa', None) - self.kv_a_layernorm = kwargs.get('kv_a_layernorm', None) - self.num_queries_per_kv = self.num_heads // self.num_kv_heads - self.tp_size = get_tensor_model_parallel_world_size() - - ascend_config = get_ascend_config() - self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled - self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz - self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp - 
self.running_in_graph = False - self.prefill_mask = None - self.ring_mla_mask_size = 512 - - self.speculative_config = get_current_vllm_config().speculative_config - - def _v_up_proj_and_o_proj(self, x, enable_multistream_mla: bool = False): - # Convert from (B, N, L) to (N, B, L) - x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1) - # Multiply (N, B, L) x (N, L, V) -> (N, B, V) - x = torch.bmm(x, self.W_UV) - # Convert from (N, B, V) to (B, N * V) - x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim) - if hasattr(self, "running_in_graph") and not self.running_in_graph: - return x - MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024 # 16MB - maybe_npu_prefetch(self.o_proj.weight, - x, - max_size=MAX_O_PROJ_PREFETCH_SIZE, - enabled=enable_multistream_mla) - return self.o_proj(x, is_prefill=False)[0] - - # Return `ql_nope`, `q_pe` - def _q_proj_and_k_up_proj(self, x): - q_nope, q_pe = self.q_proj(x)[0]\ - .view(-1, self.num_heads, self.qk_head_dim)\ - .split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) - - # Convert from (B, N, P) to (N, B, P) - q_nope = q_nope.transpose(0, 1) - # Multiply (N, B, P) x (N, P, L) -> (N, B, L) - ql_nope = torch.bmm(q_nope, self.W_UK_T) - # Convert from (N, B, L) to (B, N, L) - return ql_nope.transpose(0, 1), q_pe - - def process_weights_after_loading(self, act_dtype: torch.dtype): - - def get_layer_weight(layer): - WEIGHT_NAMES = ("weight", "qweight", "weight_packed") - for attr in WEIGHT_NAMES: - if hasattr(layer, attr): - return getattr(layer, attr) - raise AttributeError( - f"Layer '{layer}' has no recognized weight attribute:" - f" {WEIGHT_NAMES}.") - - def get_and_maybe_dequant_weights(layer: LinearBase): - if not isinstance(layer.quant_method, UnquantizedLinearMethod): - # NOTE: This should only be used offline, since it's O(N^3) - eye = torch.eye(layer.input_size_per_partition, - dtype=act_dtype, - device=get_layer_weight(layer).device) - dequant_weights = layer.quant_method.apply(layer, - eye, - bias=None) - del eye - # standardize to (output, input) - return dequant_weights.T - return layer.weight - - # we currently do not have quantized bmm's which are needed for - # `W_UV` and `W_UK_T`, we we just store fp16/bf16 copies and perform - # the bmm's in 16-bit, the extra memory overhead of this is fairly low - kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T - assert kv_b_proj_weight.shape == ( - self.kv_lora_rank, - self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)), ( - f"{kv_b_proj_weight.shape=}, " - f"{self.kv_lora_rank=}, " - f"{self.num_heads=}, " - f"{self.qk_nope_head_dim=}, " - f"{self.v_head_dim=}") - kv_b_proj_weight = kv_b_proj_weight.view( - self.kv_lora_rank, - self.num_heads, - self.qk_nope_head_dim + self.v_head_dim, - ) - - W_UK, W_UV = kv_b_proj_weight.split( - [self.qk_nope_head_dim, self.v_head_dim], dim=-1) - - # Convert from (L, N, V) to (N, L, V) - self.W_UV = W_UV.transpose(0, 1).contiguous() - # Convert from (L, N, P) to (N, P, L) - self.W_UK_T = W_UK.permute(1, 2, 0).contiguous() - - # Waiting for BMM NZ support - # self.W_UV.data = torch_npu.npu_format_cast(self.W_UV.data, 29) - # self.W_UK_T.data = torch_npu.npu_format_cast(self.W_UK_T.data, 29) - - def _compute_prefill_context( - self, - query: torch.Tensor, - kv_c_and_k_pe_cache: Tuple[torch.Tensor], - rope_dim: int, - attn_metadata: AscendMLATorchairMetadata, - prefix_output: torch.Tensor, - prefix_lse: torch.Tensor, - ): - assert len(kv_c_and_k_pe_cache) > 1 - prefill_metadata = attn_metadata.prefill - if 
prefill_metadata is None or prefill_metadata.chunked_context is None: - return prefix_output, prefix_lse - - iters = len(prefill_metadata.chunked_context.seq_tot) - q_pe = query[..., self.qk_nope_head_dim:] - q_nope = query[..., :self.qk_nope_head_dim] - - current_seq_len = torch.tensor(prefill_metadata.query_lens, - dtype=torch.int32) - cache_kv_c = kv_c_and_k_pe_cache[0] - cache_k_pe = kv_c_and_k_pe_cache[1] - num_heads = cache_k_pe.size(2) - latent_kv_dim = kv_c_and_k_pe_cache[0].size(-1) - for i in range(iters): - toks = prefill_metadata.chunked_context.seq_tot[i] - - context_seq_len = prefill_metadata.chunked_context.chunk_seq_lens[ - i] - context_seq_len_npu = prefill_metadata.chunked_context.chunk_seq_lens_npu[ - i] - seq_len = torch.stack([current_seq_len, context_seq_len]) - kv_c_normed = torch.empty(toks, - num_heads, - latent_kv_dim, - dtype=query.dtype, - device=query.device) - k_pe = torch.empty(toks, - num_heads, - rope_dim, - dtype=query.dtype, - device=query.device) - - torch_npu.atb.npu_paged_cache_load( - cache_kv_c, - cache_k_pe, - prefill_metadata.block_table, - context_seq_len_npu, - seq_starts=prefill_metadata.chunked_context.starts[i], - key=kv_c_normed, - value=k_pe, - ) - - kv_c_normed = kv_c_normed.squeeze() - kv_nope = self.kv_b_proj(kv_c_normed)[0].view( \ - -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) - k_nope, v = kv_nope\ - .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) - k_pe = k_pe.expand((*k_nope.shape[:-1], -1)) - torch_npu.atb.npu_ring_mla( - q_nope=q_nope, - q_rope=q_pe, - k_nope=k_nope, - k_rope=k_pe, - value=v, - mask=self.prefill_mask, - seqlen=seq_len, - head_num=self.num_heads, - kv_head_num=self.num_heads, - pre_out=prefix_output, - prev_lse=prefix_lse, - qk_scale=self.scale, - kernel_type="kernel_type_high_precision", - mask_type="no_mask", - input_layout="type_bsnd", - calc_type="calc_type_default", - output=prefix_output, - softmax_lse=prefix_lse) - return prefix_output, prefix_lse - - def _forward_prefill( - self, - query: torch.Tensor, - kv_c_normed: torch.Tensor, - k_pe: torch.Tensor, - kv_c_and_k_pe_cache: Tuple[torch.Tensor], - attn_metadata: AscendMLATorchairMetadata, - ) -> torch.Tensor: - assert attn_metadata.prefill is not None - assert len(kv_c_and_k_pe_cache) > 1 - - num_tokens = query.size(0) - attn_output = torch.empty(num_tokens, - self.num_heads, - self.v_head_dim, - dtype=query.dtype, - device=query.device) - attn_lse = torch.empty(self.num_heads, - num_tokens, - dtype=torch.float32, - device=query.device) - k_nope, value = self.kv_b_proj(kv_c_normed)[0].view( - -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim).split( - [self.qk_nope_head_dim, self.v_head_dim], dim=-1) - k_pe = k_pe.expand((*k_nope.shape[:-1], -1)) - # Here is only 2 possibility of input, ChunkedPrefill or PrefillNoCache - q_pe = query[..., self.qk_nope_head_dim:] - q_nope = query[..., :self.qk_nope_head_dim] - if self.prefill_mask is None: - if q_nope.dtype == torch.float16: - mask_value = torch.finfo(torch.float32).min - else: - mask_value = 1 - prefill_mask = torch.triu( - torch.ones(self.ring_mla_mask_size, - self.ring_mla_mask_size, - device=q_nope.device, - dtype=q_nope.dtype), 1) - self.prefill_mask = torch.where(prefill_mask == 1, mask_value, - 0).to(q_nope.dtype) - torch_npu.atb.npu_ring_mla(q_nope=q_nope, - q_rope=q_pe, - k_nope=k_nope, - k_rope=k_pe, - value=value, - mask=self.prefill_mask, - seqlen=attn_metadata.prefill.query_lens, - head_num=self.num_heads, - kv_head_num=self.num_heads, - pre_out=None, - 
prev_lse=None, - qk_scale=self.scale, - kernel_type="kernel_type_high_precision", - mask_type="mask_type_triu", - input_layout="type_bsnd", - calc_type="calc_type_first_ring", - output=attn_output, - softmax_lse=attn_lse) - attn_output, attn_lse = self._compute_prefill_context( \ - query, kv_c_and_k_pe_cache, self.qk_rope_head_dim, attn_metadata, attn_output, attn_lse) - - attn_output = attn_output.reshape( - [num_tokens, self.num_heads * self.v_head_dim]) - - return attn_output - - def exec_kv( - self, - hidden_states: torch.Tensor, - cos: torch.Tensor, - sin: torch.Tensor, - kv_cache: Tuple, - slots: torch.Tensor, - ): - - B = hidden_states.shape[0] - N = self.num_kv_heads - S = 1 - kv = self.kv_a_proj_with_mqa(hidden_states)[0] - # npu_kv_rmsnorm_rope_cache needs [B, N, S, D] - kv = kv.view(B, N, S, self.kv_lora_rank + self.qk_rope_head_dim) - cache_mode = "PA_NZ" if self.enable_kv_nz else "PA" - k_pe, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache( - kv, - self.kv_a_layernorm.weight, - cos, - sin, - slots.to(torch.int64), - kv_cache[1], - kv_cache[0], - epsilon=self.kv_a_layernorm.variance_epsilon, - cache_mode=cache_mode, - ) - return k_pe, k_nope, kv - - def exec_kv_prefill( - self, - hidden_states: torch.Tensor, - cos: torch.Tensor, - sin: torch.Tensor, - kv_cache: Tuple, - slots: torch.Tensor, - ): - - B = hidden_states.shape[0] - N = self.num_kv_heads - S = 1 - kv = self.kv_a_proj_with_mqa(hidden_states)[0] - # npu_kv_rmsnorm_rope_cache needs [B, N, S, D] - kv = kv.view(B, N, S, self.kv_lora_rank + self.qk_rope_head_dim) - cache_mode = "PA_NZ" if self.enable_kv_nz else "PA" - _, _, k_pe, k_nope = torch_npu.npu_kv_rmsnorm_rope_cache( - kv, - self.kv_a_layernorm.weight, - cos, - sin, - slots.to(torch.int64), - kv_cache[1], - kv_cache[0], - epsilon=self.kv_a_layernorm.variance_epsilon, - cache_mode=cache_mode, - is_output_kv=True, - ) - return k_pe, k_nope - - def rope_single( - self, - x: torch.Tensor, - cos: torch.Tensor, - sin: torch.Tensor, - ) -> torch.Tensor: - B, N, D = x.shape - S = 1 - x = x.view(B, N, S, D) - x = torch_npu.npu_interleave_rope(x, cos, sin) - return x.view(B, N, D) - - def _forward_decode( - self, - q_nope: torch.Tensor, - q_pe: torch.Tensor, - k_nope: torch.Tensor, - k_pe: torch.Tensor, - kv_c_and_k_pe_cache: Tuple[torch.Tensor], - attn_metadata: AscendMLATorchairMetadata, - enable_multistream_mla: bool = False, - ) -> torch.Tensor: - decode_meta = attn_metadata.decode - assert decode_meta is not None - num_tokens = q_nope.size(0) - if self.running_in_graph or self.running_chunkprefilll_with_torchair: - # shape of knope/k_pe for npu graph mode should be: - # [num_blocks, num_kv_heads, block_size, self.kv_lora_rank/self.qk_rope_head_dim] - block_size = kv_c_and_k_pe_cache[0].shape[1] - actual_seq_lengths = None - if self.enable_kv_nz: - k_nope = k_nope.view(-1, self.num_kv_heads, - self.kv_lora_rank // 16, block_size, 16) - k_pe = k_pe.view(-1, self.num_kv_heads, - self.qk_rope_head_dim // 16, block_size, 16) - input_layout = "BSND" - else: - k_nope = k_nope.view(-1, self.num_kv_heads, block_size, - self.kv_lora_rank) - k_pe = k_pe.view(-1, self.num_kv_heads, block_size, - self.qk_rope_head_dim) - input_layout = "BNSD" - - if attn_metadata.attn_state in [ - AscendAttentionState.SpecDecoding, - AscendAttentionState.ChunkedPrefill - ] and self.speculative_config is not None: - # Use TND layout for pure SpecDecoding and SpecDecoding in ChunkedPrefill - input_layout = "TND" - # [bs * q_seq_len, num_heads_per_rank, dim] - q_nope = q_nope.view(num_tokens, 
self.num_heads, -1) - q_pe = q_pe.view(num_tokens, self.num_heads, -1) - sparse_mode = 3 - spec_attn_mask = attn_metadata.decode.attn_mask # type:ignore - actual_seq_lengths = decode_meta.actual_seq_lengths_q - else: - if self.enable_kv_nz: - q_nope = q_nope.view(num_tokens, 1, self.num_heads, -1) - q_pe = q_pe.view(num_tokens, 1, self.num_heads, -1) - else: - q_nope = q_nope.view(num_tokens, self.num_heads, 1, -1) - q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1) - sparse_mode = 0 - spec_attn_mask = None - - attn_output, _ = torch_npu.npu_fused_infer_attention_score( - q_nope, - k_nope, - k_nope, - query_rope=q_pe, - key_rope=k_pe, - num_heads=self.num_heads, - num_key_value_heads=self.num_kv_heads, - input_layout=input_layout, - atten_mask=spec_attn_mask, - sparse_mode=sparse_mode, - scale=self.scale, - antiquant_mode=0, - antiquant_scale=None, - block_table=decode_meta.block_table, - block_size=block_size, - actual_seq_lengths_kv=decode_meta.seq_lens_list, - actual_seq_lengths=actual_seq_lengths) - else: - # The MLA_PA path will be used as default path in the future, `_npu_paged_attention_mla` will - # be removed after the torch_npu contains `torch_npu.atb.npu_multi_head_latent_attention` become - # public available - assert len(kv_c_and_k_pe_cache) > 1 - if envs_ascend.VLLM_ASCEND_MLA_PA: - attn_output = torch_npu.atb.npu_multi_head_latent_attention( - q_nope, q_pe, kv_c_and_k_pe_cache[0], - kv_c_and_k_pe_cache[1], attn_metadata.decode.block_table, - attn_metadata.decode.seq_lens, self.num_heads, self.scale, - self.num_kv_heads) - else: - q = torch.cat([q_nope, q_pe], dim=-1) - attn_output = torch.empty( - [num_tokens, self.num_heads, self.kv_lora_rank], - dtype=q.dtype, - device=q.device) - k_cache = torch.cat( - [kv_c_and_k_pe_cache[0], kv_c_and_k_pe_cache[1]], dim=-1) - torch_npu._npu_paged_attention_mla( - query=q, - key_cache=k_cache, - num_kv_heads=self.num_kv_heads, - num_heads=self.num_heads, - scale_value=self.scale, - block_table=attn_metadata.decode. - block_table, # type:ignore - context_lens=attn_metadata.decode.seq_lens, # type:ignore - mla_vheadsize=self.kv_lora_rank, - out=attn_output) - - return self._v_up_proj_and_o_proj(attn_output, enable_multistream_mla) - - def forward( - self, - layer: AttentionLayer, - hidden_states_or_q_c: torch.Tensor, # query in unified attn - hidden_states_or_kv_c_normed: torch.Tensor, # key in unified attn - k_pe: torch.Tensor, # value in unified attn - kv_cache: Tuple[torch.Tensor], - attn_metadata: M, - output: Optional[torch.Tensor] = None, - enable_multistream_mla: bool = False, - ckq: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - assert output is not None, "Output tensor must be provided." - if attn_metadata is None: - # Profiling run. 
- return output.fill_(0) - self.running_in_graph = self.torchair_graph_enabled and attn_metadata.attn_state in [ - AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding - ] - self.running_chunkprefilll_with_torchair = self.torchair_graph_enabled and attn_metadata.attn_state == AscendAttentionState.ChunkedPrefill - num_actual_toks = attn_metadata.num_actual_tokens - if k_pe is None and not self.running_in_graph: - kv_c, k_pe = self.kv_a_proj_with_mqa( - hidden_states_or_kv_c_normed)[0].split( - [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) - kv_c_normed = self.kv_a_layernorm(kv_c.contiguous()) - else: - kv_c_normed = hidden_states_or_kv_c_normed - assert attn_metadata.num_decodes is not None and \ - attn_metadata.num_prefills is not None and \ - attn_metadata.num_decode_tokens is not None - has_decode = attn_metadata.num_decodes > 0 - has_prefill = attn_metadata.num_prefills > 0 - num_decode_tokens = attn_metadata.num_decode_tokens - if not self.running_in_graph: - # Inputs and outputs may be padded for CUDA graphs - output_padded = output - output = output[:num_actual_toks, ...] - if not self.torchair_graph_enabled: - kv_c_normed = kv_c_normed[:num_actual_toks, ...] - prefill_k_c_normed = kv_c_normed[num_decode_tokens:] - if not self.running_in_graph: - hidden_states_or_q_c = hidden_states_or_q_c[:num_actual_toks, ...] - prefill_hs_or_q_c = hidden_states_or_q_c[num_decode_tokens:] - decode_hs_or_q_c = hidden_states_or_q_c[:num_decode_tokens] - prefill_hs = hidden_states_or_kv_c_normed[num_decode_tokens:] - # if not self.torchair_graph_enabled: - k_pe = k_pe[:num_actual_toks, ...] - k_pe = k_pe.unsqueeze(1) - decode_k_pe = k_pe[:num_decode_tokens] - prefill_k_pe = k_pe[num_decode_tokens:] - else: - decode_hs_or_q_c = hidden_states_or_q_c - if has_decode: - decode_k_nope = None - assert attn_metadata.decode is not None - if self.running_in_graph or self.running_chunkprefilll_with_torchair: - cos = attn_metadata.decode.cos - sin = attn_metadata.decode.sin - if self.running_chunkprefilll_with_torchair: - decode_hs = ( - hidden_states_or_kv_c_normed[:num_decode_tokens]) - slots = attn_metadata.slot_mapping[:num_decode_tokens] - decode_k_pe, decode_k_nope, decode_kv = self.exec_kv( - decode_hs, cos, sin, kv_cache, slots) - else: - with npu_stream_switch("mla_secondary", - 0, - enabled=enable_multistream_mla): - npu_wait_tensor(hidden_states_or_kv_c_normed, - ckq, - enabled=enable_multistream_mla) - decode_k_pe, decode_k_nope, decode_kv = self.exec_kv( - hidden_states_or_kv_c_normed, cos, sin, kv_cache, - attn_metadata.slot_mapping) - # Without explicitly controlling the order, IndexByTensor operations - # would be placed after `matmul W_KV_T` hindering the overlapping of - # KvRmsNormRopeCache and SingleRope. - npu_wait_tensor(decode_hs_or_q_c, - cos, - enabled=enable_multistream_mla) - npu_wait_tensor(decode_hs_or_q_c, - sin, - enabled=enable_multistream_mla) - npu_wait_tensor(decode_hs_or_q_c, - decode_kv, - enabled=enable_multistream_mla) - - decode_ql_nope, decode_q_pe = \ - self._q_proj_and_k_up_proj(decode_hs_or_q_c) - if self.running_in_graph: - with npu_stream_switch("mla_secondary", - 0, - enabled=enable_multistream_mla): - npu_wait_tensor(decode_q_pe, - decode_k_pe, - enabled=enable_multistream_mla) - decode_q_pe = self.rope_single(decode_q_pe, cos, sin) - elif self.running_chunkprefilll_with_torchair: - decode_q_pe = self.rope_single(decode_q_pe, cos, sin) - else: - decode_q_pe[...], decode_k_pe[...] 
= self.rotary_emb( - attn_metadata.decode.input_positions, - decode_q_pe.contiguous(), decode_k_pe) - if has_prefill: - assert attn_metadata.prefill is not None - prefill_q = self.q_proj(prefill_hs_or_q_c)[0]\ - .view(-1, self.num_heads, self.qk_head_dim) - prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:] - prefill_q_nope = prefill_q[..., :self.qk_nope_head_dim] - if self.torchair_graph_enabled: - num_tokens = prefill_hs_or_q_c.shape[0] - cos = attn_metadata.prefill.cos - sin = attn_metadata.prefill.sin - - prefill_q_pe = self.rope_single(prefill_q_pe, cos, sin) - prefill_k_pe, prefill_k_nope = self.exec_kv_prefill( - prefill_hs, cos, sin, kv_cache, - attn_metadata.slot_mapping[num_decode_tokens:]) - - kv_c_normed = prefill_k_nope[:num_actual_toks, ...] - prefill_k_c_normed = prefill_k_nope - prefill_k_pe = prefill_k_pe.view(num_tokens, self.num_kv_heads, - -1) - prefill_q = torch.cat([prefill_q_nope, prefill_q_pe], dim=-1) - else: - prefill_q_pe[...], prefill_k_pe[...] = self.rotary_emb( - attn_metadata.prefill.input_positions, - prefill_q_pe.contiguous(), prefill_k_pe) - - assert len( - kv_cache - ) > 1, "the number of kv cache should be greater than 1, namely (nope_cache and rope_cache)" - if self.torchair_graph_enabled: - if kv_cache[0].numel() > 0 and has_prefill: - slots = attn_metadata.slot_mapping - # NOTE: Separate the kv cache in advance to avoid OOM or other issues - torch_npu._npu_reshape_and_cache( - key=kv_c_normed.view(num_tokens, self.num_kv_heads, -1), - value=prefill_k_pe, - key_cache=kv_cache[0], - value_cache=kv_cache[1], - slot_indices=slots[num_decode_tokens:]) - else: - kv_c_normed = kv_c_normed.view( - [num_actual_toks, self.num_kv_heads, -1]) - torch_npu._npu_reshape_and_cache( - key=kv_c_normed, - value=k_pe, - key_cache=kv_cache[0], - value_cache=kv_cache[1], - slot_indices=attn_metadata.slot_mapping) - if not self.running_in_graph: - o_proj_input_shape = (num_actual_toks, - self.num_heads * self.v_head_dim) - o_proj_input = torch.empty(o_proj_input_shape, - dtype=hidden_states_or_q_c.dtype, - device=hidden_states_or_q_c.device) - if has_prefill: - # FIX: aicore move should be also placed on the comm stream in dbo, - # otherwise it may affect the accuracy - # TODO: use an elegant way to overlap - output_prefill = self._forward_prefill(prefill_q, - prefill_k_c_normed, - prefill_k_pe, kv_cache, - attn_metadata) - o_proj_input[num_decode_tokens:] = output_prefill - - if has_decode: - if self.running_in_graph: - return self._forward_decode(decode_ql_nope, decode_q_pe, - decode_k_nope, decode_k_pe, - kv_cache, attn_metadata, - enable_multistream_mla) - else: - output_decode = self._forward_decode(decode_ql_nope, - decode_q_pe, - decode_k_nope, - decode_k_pe, kv_cache, - attn_metadata) - o_proj_input[:num_decode_tokens] = output_decode - - MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024 # 16MB - - maybe_npu_prefetch(self.o_proj.weight, - o_proj_input, - max_size=MAX_O_PROJ_PREFETCH_SIZE, - enabled=enable_multistream_mla) - - output[...] = self.o_proj( - o_proj_input, - is_prefill=True, - is_force_scatter=self.enable_shared_expert_dp)[0] - - del o_proj_input - return output_padded diff --git a/vllm_ascend/torchair/torchair_model_runner.py b/vllm_ascend/torchair/torchair_model_runner.py deleted file mode 100644 index 012183e2..00000000 --- a/vllm_ascend/torchair/torchair_model_runner.py +++ /dev/null @@ -1,574 +0,0 @@ -# -# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. -# Copyright 2023 The vLLM team. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# This file is a part of the vllm-ascend project. -# Adapted from vllm-project/vllm/vllm/worker/gpu_model_runner.py -# isort: skip_file - -import math -import types -from typing import Any, Optional - -import numpy as np -import torch -import torch.distributed as dist -import torch.nn as nn -import torch_npu -from vllm.config import CUDAGraphMode, VllmConfig -from vllm.distributed import get_tensor_model_parallel_world_size -from vllm.distributed.parallel_state import get_dp_group -from vllm.forward_context import get_forward_context -from vllm.logger import logger - -import vllm_ascend.envs as envs_ascend -from vllm_ascend.ascend_config import get_ascend_config -from vllm_ascend.platform import NPUPlatform -from vllm_ascend.spec_decode import get_spec_decode_method -from vllm_ascend.torchair.utils import ( - TORCHAIR_CACHE_DIR, TorchairCommonAttentionMetadata, - check_torchair_cache_exist, converting_weight_acl_format, - register_torchair_model, torchair_ops_patch, - torchair_quant_method_register, write_kv_cache_bytes_to_file) -from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ, - AscendDeviceType, get_ascend_device_type) -from vllm_ascend.worker.model_runner_v1 import NPUModelRunner - - -class NPUTorchairModelRunner(NPUModelRunner): - - def __init__(self, vllm_config: VllmConfig, device: torch.device): - self.ascend_config = get_ascend_config() - self.enable_shared_expert_dp = self.ascend_config.enable_shared_expert_dp - super().__init__(vllm_config, device) - if self.speculative_config: - self.actual_seq_lengths_q = list( - range(self.decode_token_per_req, self.max_num_tokens + 1, - self.decode_token_per_req)) - self.attn_metadata_builder = self.attn_backend.get_builder_cls()( - None, None, vllm_config, device) - self.use_sparse = hasattr(self.model_config.hf_config, "index_topk") - - register_torchair_model() - torchair_ops_patch() - torchair_quant_method_register() - if self.enable_shared_expert_dp: - return - self.new_kv_cache_bytes = -1 - self.torchair_compiled_model = None # type: ignore - self.torchair_compiled_models = {} # type: ignore - self.use_cached_npu_graph = self.ascend_config.torchair_graph_config.use_cached_graph - self.use_cached_kv_cache_bytes = self.ascend_config.torchair_graph_config.use_cached_kv_cache_bytes - self.torchair_graph_batch_sizes = self.ascend_config.torchair_graph_config.graph_batch_sizes - if self.ascend_config.torchair_graph_config.graph_batch_sizes_init: - self.init_torchair_graph_batch_sizes() - - self.update_torchair_graph_batch_sizes() - - torch._dynamo.cache_size.config.cache_size_limit += len( - self.torchair_graph_batch_sizes) - torch._dynamo.config.capture_dynamic_output_shape_ops = True - torch._logging.set_logs( - recompiles=envs_ascend.VLLM_ASCEND_TRACE_RECOMPILES) - - self._check_batch_sizes_consistency() - - def _set_up_drafter(self): - super()._set_up_drafter() - if self.speculative_config: - # Torchair do not support disable_padded_drafter_batch - # Enforce to 
disable this feature - self.speculative_config.disable_padded_drafter_batch = True - - def _get_drafter(self): - return get_spec_decode_method(self.speculative_config.method, - self.vllm_config, - self.device, - self, - is_torchair_graph=True) - - def _may_pad_kv_consumer_num_seq(self): - # pd disaggregation scenario need redundant_batch_sizes to avoid each batch's seq_len exceed 16 tokens - # self.max_num_reqs here is greater than the actual maximum request number - if self.decode_token_per_req > 1 and self.is_kv_consumer: - # applied only when speculative decoding is active - FIA_SEQ_LEN_LIMIT = 16 - new_max_num_reqs = self.max_num_reqs + math.ceil( - self.max_num_reqs / FIA_SEQ_LEN_LIMIT) + math.ceil( - (self.max_num_reqs * self.decode_token_per_req) / - (FIA_SEQ_LEN_LIMIT**2)) - if self.max_num_reqs < new_max_num_reqs: - logger.warning( - f"max_num_reqs is updated to {new_max_num_reqs}") - self.max_num_reqs = new_max_num_reqs - - def _init_mc2_tokens_capacity(self): - # NOTE: To be clear, we need to make sure that during graph capture, the number of - # tokens is less than or equal to mc2_tokens_capacity. According to _set_cudagraph_sizes, - # the max number of tokens in graph is min(max_num_seqs * uniform_decode_query_len, 512). - max_num_tokens = self.max_num_reqs * self.uniform_decode_query_len - tp_size = self.parallel_config.tensor_parallel_size - # Use integer arithmetic for ceiling division. - max_graph_batch_size = self.calculate_new_torchair_graph_batch_size( - max_num_tokens, tp_size) - self.mc2_tokens_capacity = max_graph_batch_size - - if get_ascend_device_type( - ) == AscendDeviceType._910_93 and self.mc2_tokens_capacity > 512: - logger.error( - f"A3: the max number of tokens must smaller then 512, but now is {self.mc2_tokens_capacity}" - ) - if get_ascend_device_type( - ) == AscendDeviceType._910B and self.mc2_tokens_capacity > 256: - logger.error( - f"A2: the max number of tokens must smaller then 256, but now is {self.mc2_tokens_capacity}" - ) - - def _sync_metadata_across_dp( - self, num_tokens: int, - with_prefill: bool) -> tuple[int, Optional[torch.Tensor], bool]: - """Override from NPUModelRunner to pad num_tokens""" - if self.enable_shared_expert_dp: - # Padding is not required for shared_expert_dp cases in eager mode. 
- return num_tokens, None, with_prefill - if self.dp_size == 1: - if not with_prefill: - maybe_padded_num_tokens = self.select_torchair_padded_batch_size( - num_tokens) - return maybe_padded_num_tokens, None, with_prefill - return num_tokens, None, with_prefill - - num_tokens_across_dp = torch.zeros(self.dp_size + 1, - dtype=torch.int32, - device="npu") - num_tokens_across_dp[self.dp_rank] = num_tokens - num_tokens_across_dp[-1] = int(with_prefill) - dist.all_reduce(num_tokens_across_dp, - group=get_dp_group().device_group) - with_prefill = bool(num_tokens_across_dp[-1]) - num_tokens_across_dp = num_tokens_across_dp[:-1] - - if not with_prefill: - max_num_token = num_tokens_across_dp.max().item() - maybe_padded_num_tokens = self.select_torchair_padded_batch_size( - max_num_token) - num_tokens_across_dp = torch.full((self.dp_size, ), - maybe_padded_num_tokens, - dtype=torch.int32, - device="npu") - else: - maybe_padded_num_tokens = num_tokens - - return maybe_padded_num_tokens, num_tokens_across_dp, with_prefill - - def _build_dummy_attn_metadata( - self, - with_prefill: bool, - num_reqs: int, - num_tokens: int, - max_query_len: int, - num_scheduled_tokens: np.ndarray, - aclgraph_runtime_mode: Optional[CUDAGraphMode] = None, - force_attention: bool = False, - ) -> Optional[dict[str, Any]]: - # NOTE: If torchair graph mode and not with_prefill, - # we can't skip_attn, it will cause graph recompile. - if with_prefill or self.enable_shared_expert_dp: - attn_metadata = super()._build_dummy_attn_metadata( - with_prefill, num_reqs, num_tokens, max_query_len, - num_scheduled_tokens, aclgraph_runtime_mode, force_attention) - else: - common_attn_metadata = TorchairCommonAttentionMetadata( - num_reqs=num_reqs, - num_actual_tokens=1, - actual_seq_lengths_q=self.actual_seq_lengths_q, - attn_mask=self.attn_mask, - spec_attn_mask=self.spec_attn_mask, - decode_token_per_req=self.decode_token_per_req, - ) - attn_metadata = self.attn_metadata_builder.build_torchair_graph_dummy( - common_attn_metadata) - return attn_metadata - - def _generate_dummy_run_hidden_states(self, with_prefill, - is_torchair_compile, input_ids, - positions, attn_metadata, num_tokens, - intermediate_tensors, inputs_embeds): - if with_prefill or self.enable_shared_expert_dp: - if get_ascend_device_type() == AscendDeviceType._310P: - converting_weight_acl_format(self.model, ACL_FORMAT_FRACTAL_ND) - hidden_states = super()._generate_dummy_run_hidden_states( - with_prefill, is_torchair_compile, input_ids, positions, - attn_metadata, num_tokens, intermediate_tensors, inputs_embeds) - else: - # Only mark static while compiling - if is_torchair_compile: - torch._dynamo.mark_static(input_ids) - torch._dynamo.mark_static(positions) - torch._dynamo.mark_static(attn_metadata.decode.block_table) - torch._dynamo.mark_static(attn_metadata.decode.input_positions) - torch._dynamo.mark_static(get_forward_context().mc2_mask) - if hasattr(attn_metadata.decode, "sin"): - torch._dynamo.mark_static(attn_metadata.decode.sin) - torch._dynamo.mark_static(attn_metadata.decode.cos) - torch._dynamo.mark_static(attn_metadata.slot_mapping) - if self.speculative_config: - torch._dynamo.mark_static(attn_metadata.decode.attn_mask) - for kv in self.kv_caches: - assert isinstance(kv, tuple), "kv_cache must be a tuple" - torch._dynamo.mark_static(kv[0]) - torch._dynamo.mark_static(kv[1]) - if get_ascend_device_type() == AscendDeviceType._310P: - converting_weight_acl_format(self.model, ACL_FORMAT_FRACTAL_NZ) - - compiled_model = 
self._get_torchair_lazy_compiled_model(num_tokens) - model_kwargs = {} - model_kwargs["kv_caches"] = self.kv_caches - model_kwargs["attn_metadata"] = attn_metadata - hidden_states = compiled_model( - input_ids=input_ids, - positions=positions, - intermediate_tensors=intermediate_tensors, - inputs_embeds=None, - **model_kwargs, - ) - return hidden_states - - def _convert_torch_format(self, kv_cache): - if self.enable_shared_expert_dp: - return super()._convert_torch_format(kv_cache) - kv_cache = torch_npu.npu_format_cast(kv_cache, ACL_FORMAT_FRACTAL_ND) - return kv_cache - - def _compile_torchair_graph(self, torchair_graph_batch_sizes) -> None: - # Trigger torchair graph capture for specific shapes. - # Capture the large shapes first so that the smaller shapes - # can reuse the memory pool allocated for the large shapes. - for idx, num_tokens in enumerate(reversed(torchair_graph_batch_sizes)): - for _ in range(self.vllm_config.compilation_config. - cudagraph_num_of_warmups): - self._dummy_run(num_tokens, is_torchair_compile=True) - self._dummy_run(num_tokens, is_torchair_compile=True) - logger.info("Batchsize %d is compiled successfully: %d/%d.", - num_tokens, idx + 1, len(torchair_graph_batch_sizes)) - - def _capture_model(self): - """Override from NPUModelRunner to use torchair graph capture.""" - if self.enable_shared_expert_dp: - return super()._capture_model() - # TODO(NeverRaR): Calling graph_capture(device=self.device) in - # torchair graph capture can cause some issues, so now we just - # temporarily split the codepath for the two different graph patterns. - torchair_graph_batch_sizes = self.torchair_graph_batch_sizes - graph_num = len(torchair_graph_batch_sizes) - - if self.use_cached_npu_graph and not check_torchair_cache_exist(): - # If caching is enabled but does not exist (either - # use_cached_kv_cache_bytes is disabled or kv_cache_bytes are - # different), we will compile the model twice. The first time is - # used to generate the cache, and the second time is used to load the - # cache to skip the overhead caused by Dynamo guard mechanism. - logger.info( - "Cache compilation for torchair graph is enabled. Now we compile graph to genetate" - " torchair cache, this usually takes %.1f~%.1f mins.", - 0.5 * graph_num, 1.5 * graph_num) - self._compile_torchair_graph(torchair_graph_batch_sizes) - NPUPlatform.synchronize() - # Note: We reset dynamo and reload the compiled torchair cached computation graph below - # that was compiled above. This operation reduces graph launch time by 2-4ms and avoids - # runtime errors caused by configuration mismatches in graph mode. 
- torch._dynamo.reset() - self.torchair_compiled_models.clear() - if self.use_cached_npu_graph: - logger.info( - "Loading torchair graph cache, this usually takes %.1f~%.1f mins.", - 0.3 * graph_num, 0.5 * graph_num) - self._compile_torchair_graph(torchair_graph_batch_sizes) - else: - logger.info( - "Capturing torchair graph, this usually takes %.1f~%.1f mins.", - 0.5 * graph_num, 1.5 * graph_num) - self._compile_torchair_graph(torchair_graph_batch_sizes) - - if self.use_cached_kv_cache_bytes and self.new_kv_cache_bytes > 0: - write_kv_cache_bytes_to_file(torch.distributed.get_rank(), - self.new_kv_cache_bytes) - - def _use_aclgraph(self) -> bool: - if self.enable_shared_expert_dp: - return super()._use_aclgraph() - return False - - def _check_batch_sizes_consistency(self) -> None: - if not dist.is_initialized(): - return - - local = torch.tensor(self.torchair_graph_batch_sizes, - device="cpu", - dtype=torch.int32) - gathered_graph_batch_size = local.clone() - dist.all_reduce(gathered_graph_batch_size, - group=get_dp_group().cpu_group) - expected = local * self.dp_size - - if not torch.equal(gathered_graph_batch_size, expected): - diff_idxs = (gathered_graph_batch_size != expected).nonzero( - as_tuple=False).flatten().tolist() - raise AssertionError( - f"[Graph BatchSize Mismatch] Found mismatches at indices {diff_idxs}.\n" - f"Local (rank {self.dp_rank}): {local.tolist()}\n" - f"Sum over ranks: {gathered_graph_batch_size.tolist()}\n" - f"Expected if all equal: {[v * self.dp_size for v in local.tolist()]}" - ) - - def _update_graph_pad_size(self, with_prefill, graph_pad_size): - if with_prefill or self.enable_shared_expert_dp: - super()._update_graph_pad_size(with_prefill, graph_pad_size) - else: - self.graph_pad_size = graph_pad_size - - def _update_input_ids_and_positions(self, input_ids, positions, - num_input_tokens, with_prefill, - padded_num_tokens_across_dp): - """Override from NPUModelRunner to update input_ids and positions""" - input_ids, positions = super()._update_input_ids_and_positions( - input_ids, positions, num_input_tokens, with_prefill, - padded_num_tokens_across_dp) - - if with_prefill or self.enable_shared_expert_dp: - return input_ids, positions - else: - input_ids = self.input_ids[:padded_num_tokens_across_dp] - positions = self.positions[:padded_num_tokens_across_dp] - return input_ids, positions - - def _generate_process_reqs_hidden_states(self, attn_metadata, with_prefill, - padded_num_tokens_across_dp, - input_ids, positions, - intermediate_tensors, - inputs_embeds): - if attn_metadata is not None and isinstance(attn_metadata, dict): - attn_metadata = attn_metadata['model.layers.0.self_attn.attn'] - - if self.enable_shared_expert_dp: - return super()._generate_process_reqs_hidden_states( - attn_metadata, with_prefill, padded_num_tokens_across_dp, - input_ids, positions, intermediate_tensors, inputs_embeds) - model_kwargs = { - "kv_caches": self.kv_caches, - "attn_metadata": attn_metadata - } - if not with_prefill: - if get_ascend_device_type() == AscendDeviceType._310P: - converting_weight_acl_format(self.model, ACL_FORMAT_FRACTAL_NZ) - compiled_model = self._get_torchair_lazy_compiled_model( - padded_num_tokens_across_dp) - hidden_states = compiled_model( - input_ids=input_ids, - positions=positions, - intermediate_tensors=intermediate_tensors, - inputs_embeds=inputs_embeds, - **model_kwargs, - ) - else: - assert self.model is not None - if get_ascend_device_type() == AscendDeviceType._310P: - converting_weight_acl_format(self.model, ACL_FORMAT_FRACTAL_ND) - - 
hidden_states = self.model( - input_ids=input_ids, - positions=positions, - intermediate_tensors=intermediate_tensors, - inputs_embeds=inputs_embeds, - **model_kwargs, - ) - return hidden_states - - def _get_torchair_lazy_compiled_model(self, batch_size: int): - if batch_size < 0 or batch_size > self.torchair_graph_batch_sizes[-1]: - raise ValueError( - f"Bad graph batch size:{batch_size}! max_graph_batch_sizes:{self.torchair_graph_batch_sizes[-1]}" - ) - - compiled_model = self.torchair_compiled_models.get( - batch_size - ) if self.use_cached_npu_graph else self.torchair_compiled_model - - if compiled_model: - return compiled_model - - import torchair # type: ignore - from torchair import patch_for_hcom # type: ignore - - patch_for_hcom() - - if get_ascend_device_type() == AscendDeviceType._310P: - # on 300I Duo platform, we need to patch broadcast. however, this patch will be - # overwritten by patch_for_hcom in torchair. so we need to re-patch it here. - from vllm_ascend.patch.platform.patch_distributed import \ - communication_adaptation_310p - communication_adaptation_310p() - - config = torchair.CompilerConfig() - if self.ascend_config.torchair_graph_config.mode: - config.mode = self.ascend_config.torchair_graph_config.mode - config.experimental_config.frozen_parameter = \ - self.ascend_config.torchair_graph_config.enable_frozen_parameter - # enabling tiling_schedule_optimize on 300I Duo has some bugs, so we have to - # disable it on 300I Duo platform now. - config.experimental_config.tiling_schedule_optimize = get_ascend_device_type( - ) != AscendDeviceType._310P - config.experimental_config.enable_view_optimize = \ - self.ascend_config.torchair_graph_config.enable_view_optimize - torch.npu.set_compile_mode(jit_compile=False) - if not self.use_cached_npu_graph: - npu_backend = torchair.get_npu_backend(compiler_config=config) - self.torchair_compiled_model = torch.compile( - self.model, - dynamic=not self.use_sparse, - fullgraph=True, - backend=npu_backend) - return self.torchair_compiled_model - else: - # Generate a new forward proxy code object to prevent the invalidation of - # compilation cache caused by dynamo retracing - forward_proxy_name = f"{self.model.__class__.__name__}_forward_with_batch_size_{batch_size}" - forward_fn = self.model.forward - code = forward_fn.__code__ - # Mark code object with a new proxy name - modified_code = code.replace(co_name=forward_proxy_name, ) - - modified_func = types.FunctionType(modified_code, - forward_fn.__globals__, - name=forward_proxy_name, - argdefs=forward_fn.__defaults__) - - self.model.__dict__[forward_proxy_name] = modified_func.__get__( - self.model, nn.Module) - self.torchair_compiled_models[ - batch_size] = torchair.inference.cache_compile( - self.model.__dict__[forward_proxy_name], - dynamic=not self.use_sparse, - fullgraph=True, - cache_dir=TORCHAIR_CACHE_DIR, - config=config, - ge_cache=False) - return self.torchair_compiled_models[batch_size] - - def init_torchair_graph_batch_sizes(self): - start_graph_batch_size = 4 - tp_size = get_tensor_model_parallel_world_size() - - # NOTE: When use all2all | mc2, We need to slice the `num_tokens` dimension into `tp_size` blocks - start_graph_batch_size = max(start_graph_batch_size, tp_size) - - while (start_graph_batch_size <= self.max_num_reqs): - self.torchair_graph_batch_sizes.append(start_graph_batch_size) - start_graph_batch_size *= 2 - - def calculate_new_torchair_graph_batch_size(self, old_graph_batch_size, - tp_size): - cur_graph_batch_size = (old_graph_batch_size + tp_size - - 
1) // tp_size * tp_size - # MTP > 1: Cal LCMLeast Common Multiple with graph_batch_size and tp_size, - # Both adapter multi-dp and FIA operator - if self.speculative_config is not None and self.speculative_config.num_speculative_tokens > 1: - cur_graph_batch_size = (tp_size * old_graph_batch_size) \ - // math.gcd(tp_size, old_graph_batch_size) - return cur_graph_batch_size - - def select_torchair_padded_batch_size(self, batch_size: int): - for padded_batch_size in self.torchair_graph_batch_sizes: - if batch_size <= padded_batch_size: - # we treat batch_size as num of requests - return padded_batch_size - raise ValueError( - f"cur batch_size is invalid, torchair_graph_batch_sizes is " - f"{self.torchair_graph_batch_sizes}, but cur batch_size is {batch_size}." - ) - - def update_torchair_graph_batch_sizes(self): - # return graph_batch_sizes according to the max number of tokens - # first pad according to the number of requests - if self.is_kv_consumer and self.speculative_config and self.speculative_config.method == 'mtp': - # pd disaggregation scenario may incorrectly calculate the batch in mtp scenario, so we force set it to max_num_reqs - self.torchair_graph_batch_sizes = [self.max_num_reqs] - logger.warning( - f"is kv_consumer, torch_graph_batch_sizes sets to [max_num_seqs] {[self.max_num_reqs]}" - ) - elif len(self.torchair_graph_batch_sizes) == 0: - self.torchair_graph_batch_sizes = [1, self.max_num_reqs] - else: - self.torchair_graph_batch_sizes = sorted( - self.torchair_graph_batch_sizes) - while self.torchair_graph_batch_sizes[-1] > self.max_num_reqs: - self.torchair_graph_batch_sizes.pop() - if len(self.torchair_graph_batch_sizes) == 0: - logger.warning( - "torch_graph_batch_sizes is invalid, reset it to [1, max_num_seqs]" - ) - self.torchair_graph_batch_sizes = [1, self.max_num_reqs] - if self.torchair_graph_batch_sizes[-1] < self.max_num_reqs: - self.torchair_graph_batch_sizes.append(self.max_num_reqs) - - # padded max number tokens = max_num_req * decode_token_per_req - self.torchair_graph_batch_sizes = [ - graph_batch_size * self.decode_token_per_req - for graph_batch_size in self.torchair_graph_batch_sizes - ] - - # NOTE: when enable_expert_parallel on A3, we need to check if `graph_batch_size` is divisible by `tp_size` - # Because we use x_active_mask for dispatch/combine op on A3, which requires that input shape should be same - # on all EP ranks - if get_ascend_device_type( - ) == AscendDeviceType._910_93 and self.parallel_config.enable_expert_parallel: - self._align_graph_size_divisible_by_tp_size() - - def _align_graph_size_divisible_by_tp_size(self): - tp_size = self.parallel_config.tensor_parallel_size - new_graph_batch_sizes = [] - for graph_batch_size in self.torchair_graph_batch_sizes: - cur_graph_batch_size = self.calculate_new_torchair_graph_batch_size( - graph_batch_size, tp_size) - if cur_graph_batch_size not in new_graph_batch_sizes and \ - cur_graph_batch_size <= self.scheduler_config.max_num_batched_tokens: - new_graph_batch_sizes.append(cur_graph_batch_size) - elif cur_graph_batch_size > self.scheduler_config.max_num_batched_tokens \ - and self.decode_token_per_req > 1: - logger.warning( - f"torchair_graph_batch_sizes {cur_graph_batch_size} is bigger than max_num_batched_tokens", - f"{self.scheduler_config.max_num_batched_tokens} will skip this batch size." 
- ) - new_max_num_reqs = math.ceil( - max(new_graph_batch_sizes) / self.decode_token_per_req) - if self.max_num_reqs != new_max_num_reqs: - logger.warning(f"max_num_reqs is updated to {new_max_num_reqs}") - self.max_num_reqs = new_max_num_reqs - if not (self.decode_token_per_req > 1 and self.is_kv_consumer): - # Do not update scheduler_config.max_num_seqs in KV consumer + MTP - # Since FIA need extra space for padding - # Enforce self.max_num_seqs > self.scheduler_config.max_num_seqs in KV consumer + MTP - self.scheduler_config.max_num_seqs = new_max_num_reqs - - if new_graph_batch_sizes != self.torchair_graph_batch_sizes: - logger.warning( - f"torchair_graph_batch_sizes are updated to {new_graph_batch_sizes}." - ) - self.torchair_graph_batch_sizes = new_graph_batch_sizes - - def _build_drafter_prepare_inputs_torchair_param(self): - if self.enable_shared_expert_dp: - return super()._build_drafter_prepare_inputs_torchair_param() - else: - return True diff --git a/vllm_ascend/torchair/torchair_mtp_proposer.py b/vllm_ascend/torchair/torchair_mtp_proposer.py deleted file mode 100644 index e06c3c57..00000000 --- a/vllm_ascend/torchair/torchair_mtp_proposer.py +++ /dev/null @@ -1,543 +0,0 @@ -import types - -import torch -import torch.nn as nn -import torchair -from torchair import patch_for_hcom -from vllm.config import (CUDAGraphMode, VllmConfig, - get_layers_from_vllm_config, set_current_vllm_config) -from vllm.forward_context import get_forward_context -from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase -from vllm.model_executor.model_loader import get_model_loader -from vllm.model_executor.model_loader.utils import \ - process_weights_after_loading -from vllm.utils.torch_utils import set_default_torch_dtype -from vllm.v1.core.sched.output import SchedulerOutput -from vllm.v1.sample.metadata import SamplingMetadata -from vllm.v1.spec_decode.metadata import SpecDecodeMetadata - -from vllm_ascend.ascend_config import get_ascend_config -from vllm_ascend.ascend_forward_context import set_ascend_forward_context -from vllm_ascend.attention.utils import AscendCommonAttentionMetadata -from vllm_ascend.spec_decode import MtpProposer -from vllm_ascend.torchair.models.torchair_deepseek_mtp import \ - TorchairDeepSeekMTP -from vllm_ascend.torchair.utils import (TORCHAIR_CACHE_DIR, - TorchairCommonAttentionMetadata) -from vllm_ascend.utils import ProfileExecuteDuration, lmhead_tp_enable - -PADDING_SLOT_ID = -1 - - -class TorchairMtpProposer(MtpProposer): - - def __init__( - self, - vllm_config: VllmConfig, - device, - runner, - ): - super().__init__(vllm_config, device, runner) - self.torchair_compiled_model = None # type: ignore - self.torchair_compiled_models = {} # type: ignore - - def load_model(self, model) -> None: - loader = get_model_loader(self.vllm_config.load_config) - - target_attn_layer_names = set( - get_layers_from_vllm_config(self.vllm_config, - AttentionLayerBase).keys()) - draft_model_config = \ - self.vllm_config.speculative_config.draft_model_config - target_device = self.vllm_config.device_config.device - - with set_default_torch_dtype( - draft_model_config.dtype), set_current_vllm_config( - self.vllm_config): - - self.model = TorchairDeepSeekMTP( - vllm_config=self.vllm_config).to(target_device) - - draft_attn_layer_names = (get_layers_from_vllm_config( - self.vllm_config, AttentionLayerBase).keys() - - target_attn_layer_names) - - assert len(draft_attn_layer_names) == 1 - self.attn_layer_name = list(draft_attn_layer_names) - - self.model.load_weights( - 
loader.get_all_weights( - self.vllm_config.speculative_config.draft_model_config, - self.model)) - process_weights_after_loading(self.model, draft_model_config, - target_device) - - @torch.inference_mode() - def dummy_run(self, - num_tokens: int, - with_prefill: bool = False, - skip_attn: bool = False, - num_reqs: int = 0, - num_tokens_across_dp=None, - aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, - batch_descriptor=None, - dummy_compute_logits=lambda hidden_states: None) -> None: - moe_comm_type = self.runner._select_moe_comm_method(num_tokens) - - if not with_prefill: - skip_attn = False - if skip_attn: - attn_metadata = None - else: - common_attn_metadata = TorchairCommonAttentionMetadata( - num_reqs=num_reqs, - num_actual_tokens=1, - actual_seq_lengths_q=self.runner.actual_seq_lengths_q, - attn_mask=self.runner.attn_mask, - spec_attn_mask=self.runner.spec_attn_mask, - decode_token_per_req=self.runner.decode_token_per_req, - ) - attn_metadata = self.runner.attn_metadata_builder.build_torchair_graph_dummy( - common_attn_metadata) - - input_ids = self.input_ids[:num_tokens] - positions = self.positions[:num_tokens] - previous_hidden_states = self.hidden_states[:num_tokens] - for _ in range(self.num_speculative_tokens): - with set_ascend_forward_context( - attn_metadata, - self.vllm_config, - num_tokens=num_tokens, - with_prefill=with_prefill, - num_tokens_across_dp=num_tokens_across_dp, - reserved_mc2_mask=self.runner.reserved_mc2_mask, - moe_comm_type=moe_comm_type, - in_profile_run=self.runner.in_profile_run, - num_actual_tokens=0): - if not with_prefill: - assert attn_metadata is not None - torch._dynamo.mark_static(input_ids) - torch._dynamo.mark_static(positions) - torch._dynamo.mark_static(previous_hidden_states) - torch._dynamo.mark_static(attn_metadata.decode.block_table) - torch._dynamo.mark_static( - attn_metadata.decode.input_positions) - if hasattr(attn_metadata.decode, "sin"): - torch._dynamo.mark_static(attn_metadata.decode.sin) - torch._dynamo.mark_static(attn_metadata.decode.cos) - torch._dynamo.mark_static(get_forward_context().mc2_mask) - torch._dynamo.mark_static(attn_metadata.slot_mapping) - torch._dynamo.mark_static(attn_metadata.decode.attn_mask) - torchair_compiled_model = self._get_torchair_lazy_compiled_model( - num_tokens) - torchair_compiled_model( - input_ids=input_ids, - positions=positions, - hidden_states=previous_hidden_states, - inputs_embeds=None, - intermediate_tensors=None, - attn_metadata=attn_metadata, - kv_caches=self.runner.kv_caches[-1:], - spec_step_idx=0) - else: - self.model(input_ids=input_ids, - positions=positions, - hidden_states=previous_hidden_states) - dummy_compute_logits(previous_hidden_states) - if with_prefill: - break - - def generate_token_ids(self, - valid_sampled_token_ids: list[list[int]], - sampling_metadata: SamplingMetadata = None, - scheduler_output: SchedulerOutput = None, - spec_decode_metadata: SpecDecodeMetadata = None, - positions: torch.Tensor = None, - num_scheduled_tokens: int = 0, - hidden_states: torch.Tensor = None, - attn_metadata=None, - aux_hidden_states: torch.Tensor = None): - if attn_metadata is not None and isinstance(attn_metadata, dict): - attn_metadata = attn_metadata['model.layers.0.self_attn.attn'] - next_token_ids: list[int] = [] - for i, token_ids in enumerate(valid_sampled_token_ids): - if token_ids: - # Common case. - next_token_id = token_ids[-1] - else: - # Partial prefill (rare case). - # Get the next token id from the request state. 
- req_id = self.runner.input_batch.req_ids[i] - req_state = self.runner.requests[req_id] - seq_len = (req_state.num_computed_tokens + - scheduler_output.num_scheduled_tokens[req_id]) - next_token_id = req_state.get_token_id(seq_len) - next_token_ids.append(next_token_id) - next_token_ids = torch.tensor(next_token_ids, - dtype=torch.int32, - device=self.device) - accepted_token_indices = None - if spec_decode_metadata is None: - # input_ids can be None for multimodal models. - target_token_ids = self.runner.input_ids[:num_scheduled_tokens] - target_positions = positions[:num_scheduled_tokens] - target_hidden_states = hidden_states[:num_scheduled_tokens] - target_slot_mapping = attn_metadata.slot_mapping - cu_num_tokens = attn_metadata.query_start_loc - else: - # TODO(woosuk): Refactor this. - num_draft_tokens = spec_decode_metadata.num_draft_tokens - num_rejected_tokens = [ - n + 1 - len(valid_sampled_token_ids[i]) if n > 0 else 0 - for i, n in enumerate(num_draft_tokens) - ] - num_rejected_tokens = torch.tensor( - num_rejected_tokens, - dtype=torch.int32, - device=self.device, - ) - cu_num_tokens, accepted_token_indices, target_token_ids, \ - target_positions, target_hidden_states, target_slot_mapping = self._torchair_prepare_inputs( - attn_metadata.query_start_loc, - num_rejected_tokens, - self.runner.input_ids[:num_scheduled_tokens], - positions[:num_scheduled_tokens], - hidden_states[:num_scheduled_tokens], - attn_metadata.slot_mapping[:num_scheduled_tokens], - ) - - draft_token_ids = self._propose_torchair( - target_token_ids=target_token_ids, - target_positions=target_positions, - target_hidden_states=target_hidden_states, - target_slot_mapping=target_slot_mapping, - next_token_ids=next_token_ids, - cu_num_tokens=cu_num_tokens, - block_table=attn_metadata.block_tables, - sampling_metadata=sampling_metadata, - token_indices=accepted_token_indices) - spec_token_ids = draft_token_ids.tolist() - return spec_token_ids - - def _torchair_prepare_inputs( - self, - # [batch_size + 1] - cu_target_query_lens: torch.Tensor, - # [batch_size] - num_rejected_tokens: torch.Tensor, - token_ids: torch.Tensor, - positions: torch.Tensor, - hidden_states: torch.Tensor, - slot_mapping: torch.Tensor, - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, - torch.Tensor, torch.Tensor]: - # cu_target_query_lens: [0, a, a + b, a + b + c] - # num_rejected_tokens: [n1, n2, n3] - # num_tokens_per_req: [a - n1, b - n2, c - n3] - # cu_num_tokens: [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3] - # token_indices: [0, 1, ..., a - n1 - 1, - # a, a + 1, ..., a + b - n2 - 1, - # a + b, a + b + 1, ..., a + b + c - n3 - 1] - # [0, a, a + b, a + b + c] -> [a, b, c] - query_len_per_req = (cu_target_query_lens[1:] - - cu_target_query_lens[:-1]) - # [a, b, c] -> [a - n1, b - n2, c - n3] - - cu_num_tokens = cu_target_query_lens - relative_index = query_len_per_req - num_rejected_tokens - 1 - token_indices = cu_num_tokens[:-1] + relative_index - # the seq len of each bath is padded to 1+num_speculative_tokens, thus input is same as the main model - target_token_ids = token_ids - target_positions = positions - target_hidden_states = hidden_states - target_slot_mapping = slot_mapping - - return cu_num_tokens, token_indices, target_token_ids, target_positions, target_hidden_states, target_slot_mapping - - def _propose_torchair( - self, - # [num_tokens] - target_token_ids: torch.Tensor, - # [num_tokens] - target_positions: torch.Tensor, - # [num_tokens, hidden_size] - target_hidden_states: torch.Tensor, - # 
[num_tokens] - target_slot_mapping: torch.Tensor, - # [batch_size] - next_token_ids: torch.Tensor, - # [batch_size + 1] starting with 0 - cu_num_tokens: torch.Tensor, - # [batch_size, max_num_blocks_per_req] - block_table: torch.Tensor, - sampling_metadata: SamplingMetadata, - token_indices=None) -> torch.Tensor: - num_tokens = target_token_ids.shape[0] - batch_size = next_token_ids.shape[0] - last_token_indices = cu_num_tokens[1:] - 1 - - # Shift the input ids by one token. - # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3] - self.input_ids[:num_tokens - 1] = target_token_ids[1:] - # Replace the last token with the next token. - # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4] - if token_indices is not None: - last_token_indices = token_indices - - self.input_ids[last_token_indices] = next_token_ids - - query_lens = cu_num_tokens[1:] - cu_num_tokens[:-1] - max_query_len = query_lens.max().item() - - # FIXME: reorder_batch() needs to be called before build() - # because fields of attn_metadata_builder needs to be updated. - # However, currently reorder_batch() takes input_batch and - # scheduler_output as arguments, we should probably refactor - # the method to use new data structures which are independent - # from input_batch and scheduler_output. - # self.runner.attn_metadata_builder.reorder_batch( - # input_batch=self.runner.input_batch, - # scheduler_output=self.runner.scheduler_output, - # ) - - if not self.runner.with_prefill: - # Torchair graph mode, padding is same as the main model - num_input_tokens = self.runner.graph_pad_size - elif (self.runner.use_aclgraph - and num_tokens <= self.runner.aclgraph_batch_sizes[-1]): - # Acl graph mode, add padding to the batch size - num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens) - else: - # Eager mode, no padding needed - num_input_tokens = num_tokens - - seq_lens = target_positions[last_token_indices] + 1 - seq_lens = seq_lens.int() - common_attn_metadata = AscendCommonAttentionMetadata( - query_start_loc=cu_num_tokens[:batch_size + 1], - query_start_loc_cpu=cu_num_tokens[:batch_size + 1].cpu(), - seq_lens_cpu=seq_lens.cpu(), - num_reqs=batch_size, - num_actual_tokens=num_tokens, - max_query_len=max_query_len, - actual_seq_lengths_q=self.runner.actual_seq_lengths_q, - block_table_tensor=self.runner.input_batch.block_table[0]. 
- get_device_tensor(), - slot_mapping=target_slot_mapping, - positions=target_positions, - attn_mask=self.runner.attn_mask, - spec_attn_mask=self.runner.spec_attn_mask, - attn_state=self.runner.attn_state, - graph_pad_size=self.runner.graph_pad_size, - decode_token_per_req=self.runner.decode_token_per_req, - num_computed_tokens_cpu=None, - seq_lens=None) - - attn_metadata = self.runner.attn_metadata_builder.build( - 0, common_attn_metadata, self.runner.get_model()) - - self.positions[:num_tokens] = target_positions - self.hidden_states[:num_tokens] = target_hidden_states - - # torchair mode can reuse self.runner.num_tokens_across_dp - num_tokens_across_dp = self.runner.num_tokens_across_dp - with_prefill = self.runner.with_prefill - moe_comm_type = self.runner._select_moe_comm_method(num_input_tokens) - - for step in range(self.num_speculative_tokens): - with set_ascend_forward_context( - attn_metadata, - self.vllm_config, - num_tokens=num_input_tokens, - with_prefill=with_prefill, - num_tokens_across_dp=num_tokens_across_dp, - reserved_mc2_mask=self.runner.reserved_mc2_mask, - moe_comm_type=moe_comm_type, - in_profile_run=self.runner.in_profile_run, - num_actual_tokens=num_tokens): - with ProfileExecuteDuration().capture_async('mtp_forward'): - model_kwargs = {} - model_kwargs["attn_metadata"] = attn_metadata - - model_kwargs["kv_caches"] = self.runner.kv_caches[-1:] - if not self.runner.with_prefill: - torchair_compiled_model = self._get_torchair_lazy_compiled_model( - num_input_tokens) - hidden_states = torchair_compiled_model( - input_ids=self.input_ids[:num_input_tokens], - positions=self.positions[:num_input_tokens], - hidden_states=self. - hidden_states[:num_input_tokens], - inputs_embeds=None, - intermediate_tensors=None, - spec_step_idx=0, - **model_kwargs) - else: - hidden_states = self.model( - input_ids=self.input_ids[:num_input_tokens], - positions=self.positions[:num_input_tokens], - hidden_states=self.hidden_states[:num_input_tokens] - ) - - num_indices = last_token_indices.shape[0] - if lmhead_tp_enable(): - if not self.runner.with_prefill: - max_num_reqs_across_dp = num_input_tokens - else: - max_num_reqs_across_dp = self.vllm_config.scheduler_config.max_num_seqs - last_token_indices = nn.functional.pad( - last_token_indices, - (0, max_num_reqs_across_dp - num_indices)) - - sample_hidden_states = hidden_states[last_token_indices] - logits = self.model.compute_logits(sample_hidden_states) - if lmhead_tp_enable() and num_indices < logits.shape[0]: - logits = logits[:num_indices] - draft_token_ids = logits.argmax(dim=-1) - - if self.num_speculative_tokens == 1: - # [batch_size, 1] - return draft_token_ids.view(-1, 1) - - if step == 0: - draft_token_ids_list = [draft_token_ids] - else: - draft_token_ids_list.append(draft_token_ids) - - # prepare next mtp inputs - # mtp>1: prefill skip or decode skip last loop - if with_prefill: - for _ in range(self.num_speculative_tokens - 1): - draft_token_ids_list.append(draft_token_ids) - if step == self.num_speculative_tokens - 1 or with_prefill: - break - - attn_metadata_i = attn_metadata - - if step == 0: - positions = target_positions[last_token_indices] - hidden_states = hidden_states[last_token_indices] - slot_mapping = attn_metadata_i.slot_mapping[last_token_indices] - attn_metadata_i.slot_mapping.fill_(-1) - attn_metadata_i.query_start_loc = self.arange[:batch_size + 1] - last_token_indices = self.arange[:batch_size] - if attn_metadata_i.num_decode_tokens != 0: - attn_metadata_i.num_decode_tokens = batch_size - if not 
self.runner.with_prefill: - attn_metadata_i.num_actual_tokens = batch_size - attn_metadata_i.query_lens = [1] * batch_size - - input_ids = draft_token_ids_list[-1].int() - positions += 1 - - # NOTE(woosuk): We should handle the case where the draft model - # generates tokens beyond the max model length. Since it is complex - # to remove such requests from the batch, we keep them in the batch - # but adjust the position ids and slot mappings to avoid the - # out-of-range access during the model execution. The draft tokens - # generated with this adjustment should be ignored. - exceeds_max_model_len = positions >= self.runner.model_config.max_model_len - # Mask out the position ids that exceed the max model length. - # Otherwise, we may get out-of-range error in RoPE. - clamped_positions = torch.where(exceeds_max_model_len, 0, - positions) - # Increment the sequence lengths. - attn_metadata_i.seq_lens[:batch_size] += 1 - # For the requests that exceed the max model length, we set the - # sequence length to 1 to minimize their overheads in attention. - exceeds_max_model_len_cpu = exceeds_max_model_len.to( - attn_metadata_i.seq_lens.device, non_blocking=True) - attn_metadata_i.seq_lens[:batch_size].masked_fill_( - exceeds_max_model_len_cpu, 1) - # Mask out the slot mappings that exceed the max model length. - # Otherwise, the KV cache will be inadvertently updated with the - # padding tokens. - slot_mapping += 1 - slot_mapping.masked_fill_(exceeds_max_model_len, PADDING_SLOT_ID) - - # copy inputs to buffer for cudagraph - self.input_ids[:batch_size] = input_ids - self.positions[:batch_size] = clamped_positions - self.hidden_states[:hidden_states.shape[0]] = hidden_states - attn_metadata_i.slot_mapping[:batch_size] = slot_mapping - - if attn_metadata_i.prefill is not None: - attn_metadata_i.prefill.seq_lens = attn_metadata_i.seq_lens - attn_metadata_i.prefill.seq_lens_list = attn_metadata_i.prefill.seq_lens.tolist( - ) - attn_metadata_i.prefill.context_lens = attn_metadata_i.seq_lens - attn_metadata_i.prefill.input_positions = self.positions[: - num_input_tokens] - attn_metadata_i.prefill.max_seq_lens += 1 - attn_metadata_i.prefill.max_seq_lens = min( - attn_metadata_i.prefill.max_seq_lens, - self.runner.model_config.max_model_len) - if attn_metadata_i.decode is not None: - attn_metadata_i.decode.seq_lens = attn_metadata_i.seq_lens - attn_metadata_i.decode.seq_lens_list = attn_metadata_i.decode.seq_lens.tolist( - ) - attn_metadata_i.decode.input_positions = self.positions[: - num_input_tokens] - attn_metadata_i.decode.max_seq_lens += 1 - attn_metadata_i.decode.max_seq_lens = min( - attn_metadata_i.decode.max_seq_lens, - self.runner.model_config.max_model_len) - - # mtp>1: [batch_size, k] - draft_token_ids = torch.stack(draft_token_ids_list, dim=1) - return draft_token_ids - - def _get_torchair_lazy_compiled_model(self, batch_size: int): - if batch_size < 0 or batch_size > self.runner.torchair_graph_batch_sizes[ - -1]: - raise ValueError( - f"Bad graph batch size:{batch_size}! 
max_graph_batch_sizes:{self.runner.torchair_graph_batch_sizes[-1]}" - ) - - compiled_model = self.torchair_compiled_models.get( - batch_size - ) if self.runner.use_cached_npu_graph else self.torchair_compiled_model - - if compiled_model: - return compiled_model - - patch_for_hcom() - config = torchair.CompilerConfig() - config.experimental_config.frozen_parameter = True - config.experimental_config.tiling_schedule_optimize = True - config.experimental_config.enable_view_optimize = \ - get_ascend_config().torchair_graph_config.enable_view_optimize - torch.npu.set_compile_mode(jit_compile=False) - if not self.runner.use_cached_npu_graph: - npu_backend = torchair.get_npu_backend(compiler_config=config) - self.torchair_compiled_model = torch.compile( - self.model, - dynamic=not self.use_sparse, - fullgraph=True, - backend=npu_backend) - return self.torchair_compiled_model - else: - # Generate a new forward proxy code object to prevent the invalidation of - # compilation cache caused by dynamo retracing - forward_proxy_name = f"{self.model.__class__.__name__}_forward_with_batch_size_{batch_size}" - forward_fn = self.model.forward - code = forward_fn.__code__ - # Mark code object with a new proxy name - modified_code = code.replace(co_name=forward_proxy_name, ) - - modified_func = types.FunctionType(modified_code, - forward_fn.__globals__, - name=forward_proxy_name, - argdefs=forward_fn.__defaults__) - - self.model.__dict__[forward_proxy_name] = modified_func.__get__( - self.model, nn.Module) - self.torchair_compiled_models[ - batch_size] = torchair.inference.cache_compile( - self.model.__dict__[forward_proxy_name], - dynamic=not self.use_sparse, - fullgraph=True, - cache_dir=TORCHAIR_CACHE_DIR, - config=config, - ge_cache=False) - return self.torchair_compiled_models[batch_size] diff --git a/vllm_ascend/torchair/torchair_sfa.py b/vllm_ascend/torchair/torchair_sfa.py deleted file mode 100644 index 19e88017..00000000 --- a/vllm_ascend/torchair/torchair_sfa.py +++ /dev/null @@ -1,1317 +0,0 @@ -from dataclasses import dataclass -from typing import TYPE_CHECKING, NamedTuple, Optional, Tuple, Type, TypeVar - -import numpy as np -import torch -import torch.nn as nn -import torch.nn.functional as F -import torch_npu -from vllm.attention.backends.abstract import AttentionBackend, MLAAttentionImpl -from vllm.attention.backends.utils import PAD_SLOT_ID -from vllm.config import VllmConfig, get_current_vllm_config -from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group -from vllm.model_executor.layers.linear import (LinearBase, - UnquantizedLinearMethod) -from vllm.utils.math_utils import cdiv, round_down - -import vllm_ascend.envs as envs_ascend -from vllm_ascend.ascend_config import get_ascend_config -from vllm_ascend.attention.attention_v1 import AscendAttentionState -from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata, - split_decodes_and_prefills) -from vllm_ascend.torchair.utils import TorchairCommonAttentionMetadata -from vllm_ascend.utils import is_enable_nz -from vllm_ascend.worker.npu_input_batch import InputBatch - -if TYPE_CHECKING: - from vllm.v1.core.sched.output import SchedulerOutput - - -class AscendSFATorchairBackend(AttentionBackend): - - accept_output_buffer: bool = True - - @staticmethod - def get_name() -> str: - return "ASCEND_SFA_TORCHAIR" - - @staticmethod - def get_builder_cls(): - return AscendSFATorchairMetadataBuilder - - #NOTE: is that ok? 
- @staticmethod - def get_kv_cache_shape(num_blocks: int, block_size: int, num_kv_heads: int, - head_size: int) -> tuple[int, ...]: - return (num_blocks, block_size, num_kv_heads, head_size) - - @staticmethod - def get_impl_cls() -> Type["MLAAttentionImpl"]: - return AscendSFATorchairImpl - - -@dataclass -class AscendSFATorchairPrefillMetadata: - """ Prefill Specific Metadata for Ascend""" - - @dataclass - class TorchairChunkedContextMetadata: - # New for SFA (compared to FlashAttention) - # For handling chunked prefill - cu_seq_lens: torch.Tensor - starts: torch.Tensor - seq_tot: list[int] - max_seq_lens: list[int] - workspace: torch.Tensor - chunk_seq_lens: torch.Tensor - - attn_mask: torch.Tensor - query_lens: list[int] # Check!! - seq_lens: list[int] # Check!! - context_lens: torch.Tensor - input_positions: torch.Tensor - query_start_loc: torch.Tensor - block_table: torch.Tensor - max_query_len: int - max_seq_lens: int - sin: torch.Tensor - cos: torch.Tensor - chunked_context: Optional[TorchairChunkedContextMetadata] = None - - -@dataclass -class AscendSFATorchairDecodeMetadata: - # Input positions for rotrary embeddings since for SFA the rotary - # position embeddings are applied inside the attention backend - input_positions: torch.Tensor - block_table: torch.Tensor - seq_lens: torch.Tensor - max_seq_lens: int - seq_lens_list: list[int] - actual_seq_lengths_q: torch.Tensor - sin: torch.Tensor - cos: torch.Tensor - attn_mask: Optional[torch.Tensor] = None - - -@dataclass -class AscendSFATorchairMetadata: - """Metadata for SFACommon. - - NOTE: Please read the comment at the top of the file before trying to - understand this class - """ - # NOTE(sang): Definition of context_len, query_len, and seq_len. - # |---------- N-1 iteration --------| - # |---------------- N iteration ---------------------| - # |- tokenA -|......................|-- newTokens ---| - # |---------- context_len ----------| - # |-------------------- seq_len ---------------------| - # |-- query_len ---| - - num_actual_tokens: int # Number of tokens excluding padding. - slot_mapping: torch.Tensor - query_start_loc: torch.Tensor - seq_lens: torch.Tensor - block_tables: torch.Tensor - - # New for SFA (compared to FlashAttention) - # For handling prefill decode split - num_decodes: int - num_decode_tokens: int - num_prefills: int - - # For logging. - num_input_tokens: int = 0 # Number of tokens including padding. 
- - query_lens: Optional[list[int]] = None - # The dimension of the attention heads - head_dim: Optional[int] = None - attn_mask: torch.Tensor = None - # chunked prefill by default if no attn_states passed - attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill - - decode: Optional[AscendSFATorchairDecodeMetadata] = None - prefill: Optional[AscendSFATorchairPrefillMetadata] = None - is_prefill: bool = False - is_decode: bool = False - - def __post_init__(self): - pass - # supported_head_sizes = AscendSFABackend.get_supported_head_sizes() - # if self.head_dim is not None and self.head_dim \ - # not in supported_head_sizes: - # raise ValueError( - # f"Only {supported_head_sizes} are supported for head_dim,", - # f"received {self.head_dim}.") - - -M = TypeVar("M", bound=AscendSFATorchairMetadata) - - -class AscendSFATorchairMetadataBuilder: - """ - NOTE: Please read the comment at the top of the file before trying to - understand this class - """ - - # _attn_mask_builder = None - def __init__(self, - kv_cache_spec, - layer_names, - vllm_config: VllmConfig, - device: torch.device, - metadata_cls: Optional[AscendSFATorchairMetadata] = None): - self.metadata_cls: Optional[AscendSFATorchairMetadata] = metadata_cls \ - if metadata_cls is not None else AscendSFATorchairMetadata # type: ignore - self.vllm_config = vllm_config - self.model_config = vllm_config.model_config - self.device = device - scheduler_config = vllm_config.scheduler_config - self.block_size = vllm_config.cache_config.block_size - self.max_blocks = (vllm_config.model_config.max_model_len + - self.block_size - 1) // self.block_size - self.chunked_prefill_enabled = scheduler_config.enable_chunked_prefill - if self.chunked_prefill_enabled: - self.chunked_prefill_workspace_size = min( - # Max sure there is enough for 8 full length request or at least - # 4 pages of cache per request - max(8 * self.model_config.max_model_len, - 4 * scheduler_config.max_num_seqs * self.block_size), - # For long-context models try not to over-allocate limiting - # kv-cache space, limiting it to 64k tokens, - # which would result in the workspace being: - # 2*(576)*(64*1024) = 144mb - # (assuming 576 SFA head dim, and fp16) - # which would result in up-projected context being - # 2*(192*128)*(64*1024) = 3gb - # (assuming 192 QK head dim, 128 heads, and fp16) - 128 * 1024) - assert self.chunked_prefill_workspace_size >= \ - scheduler_config.max_num_seqs * self.block_size - self.chunked_prefill_workspace = torch.empty( - (self.chunked_prefill_workspace_size, - self.model_config.get_head_size()), - dtype=self.model_config.dtype, - device=device, - ) - ascend_config = get_ascend_config() - self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled - self.rope_dim = self.model_config.hf_text_config.qk_rope_head_dim - self.cos_cache = None - self.sin_cache = None - - def reorder_batch(self, input_batch: "InputBatch", - scheduler_output: "SchedulerOutput") -> bool: - # We now want to reorder the batch so that the "decode" requests are at - # the front and the "prefill" requests are at the using the least amount - # swaps possible. 
(NOTE for now we loosely use "decode" to mean requests - # where attention is likely memory-bound and "prefill" to mean requests - # where attention is likely compute-bound, TODO(lucas): figure out a - # better naming here) - decodes = [] - prefills = [] - - for i, req_id in enumerate(input_batch.req_ids): - num_tokens = scheduler_output.num_scheduled_tokens[req_id] - num_spec_tokens = len( - scheduler_output.scheduled_spec_decode_tokens.get(req_id, [])) - # For torch air graph mode we treat spec decoding as decode. - if self.torchair_graph_enabled: - if num_tokens - num_spec_tokens == 1: - decodes.append(i) - else: - prefills.append(i) - # For eager mode we treat spec decoding as chunked prefill. - else: - if num_tokens == 1: - decodes.append(i) - else: - prefills.append(i) - - # We hope that this is fairly minimal since decodes - # should be around for a number of iterations so hopefully they are - # relatively stationary (and new request are generally appended to the - # persistent batch so already should be at the back) - # To achieve this we loop over the decodes in descending order and - # the prefills in ascending order. We swap decodes from the "back" - # i.e. past where the last decode should be in the reodorered with - # prefills from the front of the batch. - # `decodes` and `prefills` are already in ascending order just based on - # the above loop - num_decodes = len(decodes) - num_prefills = len(prefills) - first_prefill = 0 - modified_batch = False - - for i in range(1, min(num_decodes, num_prefills) + 1): - # If the decode is at the "back" of the batch, i, we can swap it - # with the prefill closest to the front of the batch - if decodes[num_decodes - i] >= num_decodes: - input_batch.swap_states(prefills[first_prefill], - decodes[num_decodes - i]) - first_prefill += 1 - modified_batch = True - else: - break - - # Save for next `build` call - # TODO(lucas): this is a bit of a hack, we should probably have a - # better way of doing this - return modified_batch - - def _get_graph_runner_block_tables( - self, num_seqs: int, block_tables: torch.Tensor) -> torch.Tensor: - max_blocks = self.max_blocks - - graph_block_tables = torch.zeros((num_seqs, max_blocks), - dtype=block_tables.dtype, - device=block_tables.device) - - num_blocks = block_tables.size(1) - if num_blocks <= max_blocks: - graph_block_tables[:num_seqs, : - num_blocks] = block_tables[:num_seqs, : - num_blocks] - else: - graph_block_tables[:num_seqs, : - max_blocks] = block_tables[:num_seqs, : - max_blocks] - - return graph_block_tables[:, :max_blocks] - - def build_torchair_graph_dummy( - self, - common_attn_metadata: TorchairCommonAttentionMetadata, - ) -> AscendSFATorchairMetadata: - device = self.device - num_reqs = common_attn_metadata.num_reqs - block_table = torch.zeros((num_reqs, self.max_blocks), - dtype=torch.int32, - device=device) - block_table = self._get_graph_runner_block_tables( - num_reqs, block_table) - num_tokens = num_reqs * common_attn_metadata.decode_token_per_req - seq_lens = torch.zeros(num_reqs, dtype=torch.int32, device=device) - seq_lens_list = [0] * num_reqs - input_positions = torch.zeros(num_tokens, - dtype=torch.int32, - device=device).long() - slot_mapping = torch.full((num_tokens, ), - PAD_SLOT_ID, - dtype=torch.int32, - device=device) - query_start_loc = torch.full((num_reqs, ), - -1, - dtype=torch.int32, - device=device) - sin = torch.ones(num_tokens, - 1, - 1, - self.rope_dim, - dtype=self.model_config.dtype, - device=device) - cos = torch.ones(num_tokens, - 1, - 1, - self.rope_dim, 
- dtype=self.model_config.dtype, - device=device) - - if self.vllm_config.speculative_config is not None and\ - self.vllm_config.speculative_config.method == 'mtp': - attn_state = AscendAttentionState.SpecDecoding - num_decode_tokens = 2 - else: - attn_state = AscendAttentionState.DecodeOnly - num_decode_tokens = 1 - # cumsum here. - # actual_seq_lengths_q = torch.Tensor(common_attn_metadata.actual_seq_lengths_q[:num_tokens]).to(torch.int32).npu() - # actual_seq_lengths_q = torch.cumsum(actual_seq_lengths_q, dim=0).to(torch.int32).npu() - actual_seq_lengths_q = torch.arange(1, num_reqs + 1).to( - torch.int32).npu( - ) * common_attn_metadata.decode_token_per_req ############## - decode_metadata = AscendSFATorchairDecodeMetadata( - input_positions=input_positions, - block_table=block_table, - seq_lens=seq_lens, - seq_lens_list=seq_lens_list, - max_seq_lens=1, - attn_mask=common_attn_metadata.spec_attn_mask, - # actual_seq_lengths_q=torch.Tensor(common_attn_metadata.actual_seq_lengths_q[:num_reqs]).to(torch.int32).npu(), - actual_seq_lengths_q=actual_seq_lengths_q, - # actual_seq_lengths_q=torch.Tensor([1]).to(torch.int32).npu(), - sin=sin, - cos=cos, - ) - return self.metadata_cls( # type: ignore - num_input_tokens=common_attn_metadata.num_actual_tokens, - num_actual_tokens=common_attn_metadata.num_actual_tokens, - slot_mapping=slot_mapping, - head_dim=self.model_config.get_head_size(), - num_decodes=num_tokens, - num_decode_tokens=num_decode_tokens, - num_prefills=0, - attn_mask=common_attn_metadata.attn_mask, - attn_state=attn_state, - prefill=None, - decode=decode_metadata, - query_start_loc=query_start_loc, - seq_lens=seq_lens, - block_tables=block_table, - is_prefill=False, - is_decode=True) - - def build( - self, - common_prefix_len: int, - common_attn_metadata: AscendCommonAttentionMetadata, - model: nn.Module, - ) -> AscendSFATorchairMetadata: - num_reqs = common_attn_metadata.num_reqs - num_actual_tokens = common_attn_metadata.num_actual_tokens - query_start_loc = common_attn_metadata.query_start_loc - query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu - if self.torchair_graph_enabled and common_attn_metadata.attn_state in [ - AscendAttentionState.DecodeOnly, - AscendAttentionState.SpecDecoding - ]: - decode_threshold = common_attn_metadata.decode_token_per_req - else: - # TODO(xyx): remove the if condition after mla supports torch mode speculative decoding - decode_threshold = 1 - num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \ - split_decodes_and_prefills(common_attn_metadata, decode_threshold=decode_threshold) - assert num_decodes + num_prefills == num_reqs - assert num_decode_tokens + num_prefill_tokens == num_actual_tokens - - # Note(simon): be careful about the CPU <> GPU memory movement in this - # function. We should avoid GPU -> CPU sync as much as possible because - # it blocks on all previous kernels. 
- device = self.device - - block_table = (common_attn_metadata.block_table_tensor[:num_reqs]) - slot_mapping = common_attn_metadata.slot_mapping[: - num_actual_tokens].to( - device, - non_blocking=True) - input_positions = common_attn_metadata.positions[: - num_actual_tokens].long( - ) - - if self.cos_cache is None: - self.cos_cache = model.model.layers[ - 0].self_attn.rotary_emb.cos_cached - self.sin_cache = model.model.layers[ - 0].self_attn.rotary_emb.sin_cached - if self.cos_cache.dtype != self.model_config.dtype: # type: ignore - self.cos_cache = self.cos_cache.to( # type: ignore - self.model_config.dtype) # type: ignore - self.sin_cache = self.sin_cache.to( # type: ignore - self.model_config.dtype) # type: ignore - - # check CPU operation here - query_seq_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] - query_lens = query_seq_lens_cpu[:num_reqs] - seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs] - num_computed_tokens_cpu = (seq_lens - query_lens) - - prefill_metadata = None - chunked_context_metadata = None - is_prefill = False - is_decode = False - if num_prefills > 0: - reqs_start = num_decodes # prefill_start - tokens_start = num_decode_tokens - max_query_len = query_lens[tokens_start:].max().item() - max_seq_lens = seq_lens[tokens_start:].max().item() - prefill_query_start_loc = query_start_loc[ - reqs_start:] - query_start_loc[reqs_start] - - context_lens_cpu = num_computed_tokens_cpu[reqs_start:num_reqs] - max_context_len_cpu = context_lens_cpu.max().item() - num_prefills_with_context_cpu = (context_lens_cpu > 0).sum().item() - if self.chunked_prefill_enabled and max_context_len_cpu > 0: - max_context_chunk = (self.chunked_prefill_workspace_size // - num_prefills_with_context_cpu) - max_context_chunk = round_down(max_context_chunk, - self.block_size) - - assert max_context_chunk > 0 - num_chunks = cdiv(max_context_len_cpu, max_context_chunk) - chunk_starts = torch.arange(num_chunks, dtype=torch.int32) \ - .unsqueeze(1).expand(-1, num_prefills) * max_context_chunk - chunk_ends = torch.min(context_lens_cpu.unsqueeze(0), - chunk_starts + max_context_chunk) - chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0) - cu_seq_lens_cpu = torch.zeros(num_chunks, - num_prefills + 1, - dtype=torch.int32, - pin_memory=True) - torch.cumsum(chunk_seq_lens, - dim=1, - out=cu_seq_lens_cpu[:, 1:], - dtype=torch.int32) - chunked_context_metadata = \ - AscendSFATorchairPrefillMetadata.TorchairChunkedContextMetadata( - cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True), - starts=chunk_starts.to(device, non_blocking=True), - seq_tot=chunk_seq_lens.sum(dim=1).tolist(), - max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(), - chunk_seq_lens=chunk_seq_lens, - workspace=self.chunked_prefill_workspace, - ) - prefill_input_positions = input_positions[tokens_start:] - cos = self.cos_cache[ - prefill_input_positions].unsqueeze( # type: ignore - 1).unsqueeze(2) - sin = self.sin_cache[ - prefill_input_positions].unsqueeze( # type: ignore - 1).unsqueeze(2) - actual_query_lens = torch.tensor( - query_lens[tokens_start:], - dtype=torch.int32).npu() # int64->int32 - query_lens_prefill_sfa = torch.cumsum(actual_query_lens, - dim=0).to(torch.int32).npu() - seq_lens_prefill_sfa = torch.tensor(seq_lens, - dtype=torch.int32).npu() - prefill_metadata = AscendSFATorchairPrefillMetadata( - attn_mask=common_attn_metadata.attn_mask, - query_lens=query_lens_prefill_sfa, - seq_lens=seq_lens_prefill_sfa, - context_lens=seq_lens[tokens_start:], - input_positions=prefill_input_positions, - 
block_table=block_table[reqs_start:, ...], - max_query_len=max_query_len, - max_seq_lens=max_seq_lens, - query_start_loc=prefill_query_start_loc, - chunked_context=chunked_context_metadata, - sin=sin, - cos=cos, - ) - is_prefill = True - - decode_metadata = None - graph_pad_size = common_attn_metadata.graph_pad_size - use_torchair_graph = graph_pad_size != -1 - if num_decodes > 0: - # Check here!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! - # Notice that num_decodes != num_decode_tokens in SpecDecoding Scenario - actual_seq_lengths_q = query_start_loc[1:num_decodes + 1].to( - torch.int32).npu() - max_seq_lens = seq_lens[:num_decodes].max().item() - seq_lens = seq_lens[:num_decodes].to(torch.int32).npu() - # input_positions = input_positions[:num_decode_tokens] - block_table = block_table[:num_decodes, ...] - num_token_pad_size = 0 - if use_torchair_graph and common_attn_metadata.attn_state in [ - AscendAttentionState.DecodeOnly, - AscendAttentionState.SpecDecoding - ]: - num_reqs_pad_size = 0 - if graph_pad_size != 0: - pad_value = 0 - num_token_pad_size = graph_pad_size - num_decode_tokens - num_reqs_pad_size = ( - graph_pad_size // - common_attn_metadata.decode_token_per_req - num_reqs) - padded_seq_lens = seq_lens.tolist( - ) + [pad_value] * num_reqs_pad_size - else: - padded_seq_lens = seq_lens.tolist() - - seq_lens = torch.from_numpy( - np.array(padded_seq_lens).astype(np.int32)).npu() - seq_lens_list = padded_seq_lens - slot_padding = torch.full((num_token_pad_size, ), - PAD_SLOT_ID, - dtype=slot_mapping.dtype, - device=slot_mapping.device) - slot_mapping = torch.cat([slot_mapping, slot_padding]) - block_table_padding = torch.zeros( - (num_reqs_pad_size, ) + block_table.shape[1:], - dtype=block_table.dtype, - device=block_table.device) - block_table = torch.cat([block_table, block_table_padding], - dim=0) - block_table = self._get_graph_runner_block_tables( - num_reqs + num_reqs_pad_size, block_table) - position_padding = torch.zeros(num_token_pad_size, - dtype=input_positions.dtype, - device=input_positions.device) - input_positions = torch.cat( - [input_positions, position_padding]) - - # actual_seq_lengths_q = torch.cumsum(actual_seq_lengths_q, dim=0).npu() - # actual_seq_lengths_q=torch.Tensor([1]).to(torch.int32).npu() - actual_seq_lengths_q = torch.arange(1, num_reqs + 1).to( - torch.int32).npu( - ) * common_attn_metadata.decode_token_per_req - # MTP ignored - # actual_seq_lengths_q = self.pad_actual_seq_len_q( - # num_reqs_pad_size, num_reqs, actual_seq_lengths_q, - # common_attn_metadata) - else: - seq_lens_list = seq_lens.tolist() - # mtp torchair + PD scenario, last element of actual_seq_lengths_q must equal to batch_size(num_tokens) - batch_size = num_decode_tokens + num_token_pad_size - if actual_seq_lengths_q[-1] != batch_size \ - and common_attn_metadata.attn_state == AscendAttentionState.SpecDecoding: - actual_seq_lengths_q[-1] = batch_size - - cos = self.cos_cache[input_positions].unsqueeze( # type: ignore - 1).unsqueeze(2) - sin = self.sin_cache[input_positions].unsqueeze( # type: ignore - 1).unsqueeze(2) - padded_token_num = input_positions.shape[0] - actual_seq_lengths_q = torch.arange( - 1, - (padded_token_num // common_attn_metadata.decode_token_per_req) - + 1).to(torch.int32).npu( - ) * common_attn_metadata.decode_token_per_req - decode_metadata = AscendSFATorchairDecodeMetadata( - input_positions=input_positions, - block_table=block_table, - seq_lens=seq_lens, - seq_lens_list=seq_lens_list, - max_seq_lens=max_seq_lens, - 
attn_mask=common_attn_metadata.spec_attn_mask, - actual_seq_lengths_q=actual_seq_lengths_q, - sin=sin, - cos=cos) - is_decode = True - - return self.metadata_cls( # type: ignore - num_actual_tokens=num_actual_tokens, - query_lens=query_lens.tolist(), - slot_mapping=slot_mapping, - head_dim=self.model_config.get_head_size(), - num_decodes=num_decodes, - num_decode_tokens=num_decode_tokens, - num_prefills=num_prefills, - attn_mask=common_attn_metadata.attn_mask, - attn_state=common_attn_metadata.attn_state, - prefill=prefill_metadata, - decode=decode_metadata, - query_start_loc=query_start_loc, - block_tables=block_table, - seq_lens=seq_lens, - is_prefill=is_prefill, - is_decode=is_decode) - - def pad_actual_seq_len_q(self, num_reqs_pad_size, num_reqs, - actual_seq_lengths_q, common_attn_metadata): - """ - Pads actual_seq_lengths_q evenly to not exceed 16 tokens per request - in order to meet the requirement of npu_fused_infer_attention_score. - - In Torchair scenario, the lengths of the queries must be padded to the same length. - And npu_fused_infer_attention_score constraint requires the last element must equal to batch_size(num_tokens). - - For example: - batch_size=36, num_reqs_pad_size=2, num_reqs=16 - By default, each request should have inference 2 token, which means actual_seq_lengths_q should be - [2,4,6,8,10,12,14,16,18,20,22,24,26,28,30,32,34,36]. - - However, mtp torchair + PD scenario, the actual_seq_lengths_q may be - [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16] before padding, since the first decode request only has 1 token. - In order to meet the requirement of npu_fused_infer_attention_score, we need to pad actual_seq_lengths_q evenly to not exceed 16 tokens per request. - after padding actual_seq_lengths_q should be similar to [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,32,36] - """ - FIA_SEQ_LEN_LIMIT = 16 - need_padding = num_reqs_pad_size != 0 and \ - len(common_attn_metadata.actual_seq_lengths_q) > num_reqs and \ - common_attn_metadata.actual_seq_lengths_q[num_reqs] - actual_seq_lengths_q[-1] > FIA_SEQ_LEN_LIMIT - if need_padding: - padding_seq_len_q = common_attn_metadata.actual_seq_lengths_q[ - num_reqs:num_reqs + num_reqs_pad_size] - start_val = actual_seq_lengths_q[-1] - end_val = padding_seq_len_q[-1] - - num_step = len(padding_seq_len_q) - interpolated = np.round( - np.linspace(start_val, end_val, - num_step + 1)[1:]).astype(int).tolist() - assert interpolated[-1] == end_val - assert len(interpolated) == len(padding_seq_len_q) - actual_seq_lengths_q = actual_seq_lengths_q + interpolated - else: - actual_seq_lengths_q = actual_seq_lengths_q + common_attn_metadata.actual_seq_lengths_q[ - num_reqs:num_reqs + num_reqs_pad_size] - - # return actual_seq_lengths_q - return torch.Tensor(actual_seq_lengths_q).to(torch.int32).npu() - - -class PrefillSFAPreprocessResult(NamedTuple): - q_nope: Optional[torch.Tensor] = None - q_pe: Optional[torch.Tensor] = None - k_nope: Optional[torch.Tensor] = None - k_pe: Optional[torch.Tensor] = None - topk_indices: Optional[torch.Tensor] = None - query_states: Optional[torch.Tensor] = None - key_states: Optional[torch.Tensor] = None - - -class DecodeSFAPreprocessResult(NamedTuple): - q_nope: Optional[torch.Tensor] = None - q_pe: Optional[torch.Tensor] = None - # nope_cache: Optional[torch.Tensor] = None - # rope_cache: Optional[torch.Tensor] = None - topk_indices: Optional[torch.Tensor] = None - query_states: Optional[torch.Tensor] = None - key_states: Optional[torch.Tensor] = None - bsz: Optional[int] = None - - -class 
AscendSFATorchairImpl(MLAAttentionImpl): - """ - NOTE: Please read the comment at the top of the file before trying to - understand this class - """ - - def __init__( - self, - num_heads: int, - head_size: int, - scale: float, - num_kv_heads: int, - alibi_slopes: Optional[list[float]], - sliding_window: Optional[int], - kv_cache_dtype: str, - logits_soft_cap: Optional[float], - attn_type: str, - kv_sharing_target_layer_name: Optional[str], - **kwargs, - ) -> None: - self.num_heads = num_heads - self.head_size = head_size - self.scale = float(scale) - self.num_kv_heads = num_kv_heads - self.kv_cache_dtype = kv_cache_dtype - - # MLA Args - self.q_lora_rank = kwargs['q_lora_rank'] - self.kv_lora_rank = kwargs['kv_lora_rank'] - self.qk_nope_head_dim = kwargs['qk_nope_head_dim'] - self.qk_rope_head_dim = kwargs['qk_rope_head_dim'] - self.qk_head_dim = kwargs['qk_head_dim'] - self.v_head_dim = kwargs['v_head_dim'] - self.rotary_emb = kwargs['rotary_emb'] - self.q_proj = kwargs['q_proj'] - self.kv_b_proj = kwargs['kv_b_proj'] - self.o_proj = kwargs['o_proj'] - self.indexer = kwargs['indexer'] - self.kv_a_proj_with_mqa = kwargs.get('kv_a_proj_with_mqa', None) - self.kv_a_layernorm = kwargs.get('kv_a_layernorm', None) - self.q_a_proj = kwargs.get('q_a_proj', None) - self.q_a_layernorm = kwargs.get('q_a_layernorm', None) - self.decoder_layer = kwargs.get('decoder_layer', None) - self.num_queries_per_kv = self.num_heads // self.num_kv_heads - self.tp_size = get_tensor_model_parallel_world_size() - self.num_heads_per_rank = self.num_heads // self.tp_size - if self.q_a_proj is not None: - self.q_b_proj = self.q_proj - else: - self.q_b_proj = None - - ascend_config = get_ascend_config() - self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp - self.enable_prefetch = ascend_config.weight_prefetch_config.enabled - self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz - if ascend_config.torchair_graph_config.enabled: - self.graph_batch_size = ascend_config.torchair_graph_config.graph_batch_sizes[ - 0] - self.actual_seq_length = torch.arange(1, self.graph_batch_size + - 1).to(torch.int32).npu() - vllm_config = get_current_vllm_config() - self.ring_mla_mask_size = 512 - self.prefill_mask = None - - # indexer param - self.dim = self.indexer.dim - self.n_heads: int = self.indexer.n_heads # 64 - self.head_dim: int = self.indexer.head_dim # 128 - self.index_topk: int = self.indexer.index_topk # 2048 - self.wq_b = self.indexer.wq_b - self.wk = self.indexer.wk - self.weights_proj = self.indexer.weights_proj - self.k_norm = self.indexer.k_norm - self.softmax_scale = self.indexer.softmax_scale - - # Adapt torch air graph mode with spec decoding. 
- speculative_config = vllm_config.speculative_config - if speculative_config is not None: - self.spec_token_num = speculative_config.num_speculative_tokens - assert self.spec_token_num > 0 - - self.cp_size = 1 - - if self.q_a_proj is not None: - self.prefix = self.q_a_proj.prefix - else: - self.prefix = 0 - self.debug_layer_idx = int(self.prefix.split(".")[2]) - self.layers = vllm_config.model_config.hf_config.num_hidden_layers - self.first_k_dense_replace = vllm_config.model_config.hf_config.first_k_dense_replace - - def process_weights_after_loading(self, act_dtype: torch.dtype): - - def get_layer_weight(layer): - WEIGHT_NAMES = ("weight", "qweight", "weight_packed") - for attr in WEIGHT_NAMES: - if hasattr(layer, attr): - return getattr(layer, attr) - raise AttributeError( - f"Layer '{layer}' has no recognized weight attribute:" - f" {WEIGHT_NAMES}.") - - def get_and_maybe_dequant_weights(layer: LinearBase): - if not isinstance(layer.quant_method, UnquantizedLinearMethod): - # NOTE: This should only be used offline, since it's O(N^3) - eye = torch.eye(layer.input_size_per_partition, - dtype=act_dtype, - device=get_layer_weight(layer).device) - dequant_weights = layer.quant_method.apply(layer, - eye, - bias=None) - del eye - # standardize to (output, input) - return dequant_weights.T - return layer.weight - - # we currently do not have quantized bmm's which are needed for - # `W_UV` and `W_UK_T`, we we just store fp16/bf16 copies and perform - # the bmm's in 16-bit, the extra memory overhead of this is fairly low - kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T - assert kv_b_proj_weight.shape == ( - self.kv_lora_rank, - self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)), ( - f"{kv_b_proj_weight.shape=}, " - f"{self.kv_lora_rank=}, " - f"{self.num_heads=}, " - f"{self.qk_nope_head_dim=}, " - f"{self.v_head_dim=}") - kv_b_proj_weight = kv_b_proj_weight.view( - self.kv_lora_rank, - self.num_heads, - self.qk_nope_head_dim + self.v_head_dim, - ) - - self.kv_b_proj_w_k, self.kv_b_proj_w_v = kv_b_proj_weight.split( - [self.qk_nope_head_dim, self.v_head_dim], dim=-1) - - # Convert from (L, N, V) to (N, L, V) - self.kv_b_proj_w_v = self.kv_b_proj_w_v.transpose(0, 1).contiguous() - # Convert from (L, N, P) to (N, P, L) - self.kv_b_proj_w_k = self.kv_b_proj_w_k.permute(1, 2, 0).contiguous() - # Waiting for BMM NZ support - # self.W_UV.data = torch_npu.npu_format_cast(self.W_UV.data, 29) - # self.W_UK_T.data = torch_npu.npu_format_cast(self.W_UK_T.data, 29) - if envs_ascend.VLLM_ASCEND_ENABLE_MLAPO: - self._process_weights_for_fused_mlapo(act_dtype) - - def _process_weights_for_fused_mlapo(self, act_dtype: torch.dtype): - kv_a_proj_wt = self.kv_a_proj_with_mqa.weight.data.clone() - kv_a_proj_wt = kv_a_proj_wt.t().contiguous() - kv_a_proj_wt = trans_rope_weight(kv_a_proj_wt, self.qk_rope_head_dim) - kv_a_proj_wt = kv_a_proj_wt.t().contiguous() - wd_qkv = torch.cat((kv_a_proj_wt, self.q_a_proj.weight.data.clone()), - dim=-1) - - wd_qkv = wd_qkv.t().contiguous() - wd_qkv = transdata(wd_qkv, - block_size=(16, 32)).unsqueeze(0).contiguous() - if is_enable_nz(): - self.wd_qkv = torch_npu.npu_format_cast(wd_qkv, 29) - - kv_a_proj_deq_scl = self.kv_a_proj_with_mqa.deq_scale.clone() - kv_a_proj_deq_scl = kv_a_proj_deq_scl.reshape( - self.kv_lora_rank + self.qk_rope_head_dim, -1).contiguous() - kv_a_proj_deq_scl = trans_rope_weight(kv_a_proj_deq_scl, - self.qk_rope_head_dim) - kv_a_proj_deq_scl = kv_a_proj_deq_scl.view( - self.kv_lora_rank + self.qk_rope_head_dim).contiguous() - 
self.deq_scale_qkv = torch.cat( - (kv_a_proj_deq_scl, self.q_a_proj.deq_scale.clone()), - dim=-1).contiguous() - - kv_a_proj_qt_bias = self.kv_a_proj_with_mqa.quant_bias.clone() - kv_a_proj_qt_bias = kv_a_proj_qt_bias.reshape( - self.kv_lora_rank + self.qk_rope_head_dim, -1).contiguous() - kv_a_proj_qt_bias = trans_rope_weight(kv_a_proj_qt_bias, - self.qk_rope_head_dim) - kv_a_proj_qt_bias = kv_a_proj_qt_bias.view( - self.kv_lora_rank + self.qk_rope_head_dim).contiguous() - self.quant_bias_qkv = torch.cat( - (kv_a_proj_qt_bias, self.q_a_proj.quant_bias.clone()), - dim=-1).contiguous() - - wu_q = self.q_proj.weight.data.clone() - wu_q = wu_q.t().reshape(self.num_heads, - self.qk_nope_head_dim + self.qk_rope_head_dim, - -1) - wu_q = trans_rope_weight(wu_q, self.qk_rope_head_dim) - wu_q = wu_q.reshape( - self.num_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim), - -1) - wu_q = transdata(wu_q, block_size=(16, 32)).unsqueeze(0).contiguous() - if is_enable_nz(): - self.wu_q = torch_npu.npu_format_cast(wu_q, 29) - - qb_deq_scl = self.q_proj.deq_scale.data.clone() - qb_deq_scl = qb_deq_scl.reshape( - self.num_heads, self.qk_nope_head_dim + self.qk_rope_head_dim, -1) - qb_deq_scl = trans_rope_weight(qb_deq_scl, self.qk_rope_head_dim) - self.qb_deq_scl = qb_deq_scl.reshape( - self.num_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim)) - - qb_qt_bias = self.q_proj.quant_bias.data.clone() - qb_qt_bias = qb_qt_bias.reshape( - self.num_heads, self.qk_nope_head_dim + self.qk_rope_head_dim, -1) - qb_qt_bias = trans_rope_weight(qb_qt_bias, self.qk_rope_head_dim) - self.qb_qt_bias = qb_qt_bias.reshape( - self.num_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim)) - - self.gamma0 = self.decoder_layer.input_layernorm.weight.data - self.beta0 = self.decoder_layer.input_layernorm.bias.data - self.gamma1 = self.q_a_layernorm.weight.data - self.beta1 = self.q_a_layernorm.bias.data - self.gamma2 = self.kv_a_layernorm.weight.data - self.quant_scale0 = self.q_a_proj.input_scale.data - self.quant_offset0 = self.q_a_proj.input_offset.data - self.quant_scale1 = self.q_proj.input_scale.data - self.quant_offset1 = self.q_proj.input_offset.data - - def _sfa_decode_preprocess(self, hidden_states, kv_cache, attn_metadata, - need_gather_q_kv): - bsz = hidden_states.shape[0] - cos_shape = attn_metadata.decode.cos.shape - cos = attn_metadata.decode.cos.view(cos_shape[0], cos_shape[-1]) - sin = attn_metadata.decode.sin.view(cos_shape[0], cos_shape[-1]) - ctkv_scale = torch.tensor([1], - dtype=hidden_states.dtype, - device=hidden_states.device) - q_nope_scale = torch.tensor([1], - dtype=hidden_states.dtype, - device=hidden_states.device) - - decode_q_nope, _, decode_q_pe, _ = torch_npu.npu_mla_process( - hidden_states, - self.gamma0, - self.beta0, - self.wd_qkv, - self.deq_scale_qkv, - self.gamma1, - self.beta1, - self.wu_q, - self.qb_deq_scl, - self.gamma2, - cos, - sin, - self.kv_b_proj_w_k, - kv_cache[0], - kv_cache[1], - attn_metadata.slot_mapping.flatten(), - quant_scale0=self.quant_scale0, - quant_offset0=self.quant_offset0, - bias0=self.quant_bias_qkv, - quant_scale1=self.quant_scale1, - quant_offset1=self.quant_offset1, - bias1=self.qb_qt_bias, - ctkv_scale=ctkv_scale, - q_nope_scale=q_nope_scale, - cache_mode_opt="krope_ctkv", - quant_mode_opt="per_tensor_quant_asymm", - ) - decode_k_nope = kv_cache[0] - decode_k_pe = kv_cache[1] - decode_q_nope = decode_q_nope.view(bsz, self.num_heads, - self.kv_lora_rank) - decode_q_pe = decode_q_pe.view(bsz, self.num_heads, -1) - - hidden_states = 
self.decoder_layer.input_layernorm(hidden_states) - - decode_kq = self.q_a_proj(hidden_states) # q down - decode_q_c = self.q_a_layernorm(decode_kq) # q down layernorm - - topk_indices = self.indexer_select(hidden_states, - decode_q_c, - attn_metadata=attn_metadata, - kv_cache=kv_cache, - is_prefill=False) - query_states = (decode_q_nope, decode_q_pe) - key_states = (decode_k_nope, decode_k_pe) - decode_preprocess_res = DecodeSFAPreprocessResult( - q_nope=decode_q_nope, - q_pe=decode_q_pe, - topk_indices=topk_indices, - query_states=query_states, - key_states=key_states, - bsz=bsz, - ) - return decode_preprocess_res - - def forward( - self, - hidden_states: torch.Tensor, # query in unified attn - kv_cache: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], - attn_metadata: M, - need_gather_q_kv: bool = False, - output: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - assert output is not None, "Output tensor must be provided." - if attn_metadata is None: - # Profiling run. - return output.fill_(0) - - if attn_metadata.prefill is not None: - assert attn_metadata.num_decodes is not None and \ - attn_metadata.num_prefills is not None and \ - attn_metadata.num_decode_tokens is not None - - bsz = 1 - - hidden_states_prefill = hidden_states - prefill_slot_mapping = attn_metadata.slot_mapping - - prefill_kq = self.q_a_proj(hidden_states_prefill) # q down - prefill_q_c = self.q_a_layernorm(prefill_kq) # q down layernorm - prefill_kv_no_split = self.kv_a_proj_with_mqa( - hidden_states_prefill) # c_kv - - if self.enable_shared_expert_dp and self.debug_layer_idx > self.first_k_dense_replace and self.debug_layer_idx < self.layers: - prefill_kv_no_split = get_tp_group().all_gather( - prefill_kv_no_split, - 0)[attn_metadata.num_decode_tokens:attn_metadata. - num_actual_tokens] - # prefill_q_c = q_c[ - # num_decode_tokens:num_actual_tokens] - - # decode_kv_no_split = decode_kv_no_split[:num_decode_tokens] - - # prefill_kv_no_split = kv_no_split[ - # num_decode_tokens:num_actual_tokens] - # prefill_qr = prefill_q_c[num_decode_tokens:num_actual_tokens] - prefill_qr = prefill_q_c - if self.enable_shared_expert_dp and self.debug_layer_idx > self.first_k_dense_replace and self.debug_layer_idx < self.layers: - prefill_qr = get_tp_group().all_gather( - prefill_qr, - 0)[attn_metadata.num_decode_tokens:attn_metadata. - num_actual_tokens] - - prefill_q = self.q_b_proj(prefill_qr) - prefill_q = prefill_q.view(-1, self.num_heads, self.qk_head_dim) - prefill_q_nope, prefill_q_pe = torch.split( - prefill_q, [self.qk_nope_head_dim, self.qk_rope_head_dim], - dim=-1) - prefill_q_nope = prefill_q_nope.view( - -1, self.num_heads, self.qk_nope_head_dim).transpose(0, 1) - prefill_q_nope = (torch.matmul(prefill_q_nope, - self.kv_b_proj_w_k).transpose( - 1, - 0).view(-1, self.num_heads, - self.kv_lora_rank)) - prefill_q_pe = prefill_q_pe.unsqueeze(2) - - # stream2 kv - - nope_cache = kv_cache[0] - rope_cache = kv_cache[1] - cos = attn_metadata.prefill.cos - sin = attn_metadata.prefill.sin - cos_q, sin_q = cos, sin - - prefill_q_pe = torch_npu.npu_interleave_rope( - prefill_q_pe, cos_q, sin_q) # BNSD - prefill_q_pe = prefill_q_pe.squeeze(2) #BSH - # q[..., self.qk_nope_head_dim:] = prefill_q_pe # TODO:???? 
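
For orientation, the prefill q_nope branch above is absorbed into the latent kv_lora_rank space through a per-head batched matmul with the pre-split kv_b_proj weight. A shape-only sketch, using toy dimensions that are assumptions rather than the real model config:

import torch

# Illustrative shapes only (not the real configuration):
num_heads, qk_nope_head_dim, kv_lora_rank, num_tokens = 4, 128, 512, 3

# kv_b_proj_w_k is stored per head as (N, P, L) so q_nope can be projected
# directly into the latent space that the compressed KV cache holds.
w_uk = torch.randn(num_heads, qk_nope_head_dim, kv_lora_rank)
q_nope = torch.randn(num_tokens, num_heads, qk_nope_head_dim)

# (N, T, P) @ (N, P, L) -> (N, T, L), then back to (T, N, L)
absorbed = torch.matmul(q_nope.transpose(0, 1), w_uk).transpose(0, 1)
assert absorbed.shape == (num_tokens, num_heads, kv_lora_rank)
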
- - prefill_latent_cache = prefill_kv_no_split # (B,S,N,D) - prefill_k_pe, prefill_k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache( - prefill_latent_cache.view( - -1, 1, 1, self.kv_lora_rank + self.qk_rope_head_dim), - self.kv_a_layernorm.weight, - cos.view(-1, 1, 1, self.qk_rope_head_dim), - sin.view(-1, 1, 1, self.qk_rope_head_dim), - prefill_slot_mapping.to(torch.int64), - rope_cache, - nope_cache, - k_rope_scale=None, - c_kv_scale=None, - k_rope_offset=None, - c_kv_offset=None, - epsilon=self.kv_a_layernorm.variance_epsilon, - cache_mode="PA") - - topk_indices = self.indexer_select(x=hidden_states_prefill, - qr=prefill_qr, - kv_cache=kv_cache, - attn_metadata=attn_metadata, - is_prefill=True) - query_states = (prefill_q_nope, prefill_q_pe) - key_states = (prefill_k_nope, prefill_k_pe) - q_nope, q_pe = query_states - k_nope, k_rope = key_states - prefill_metadata = attn_metadata.prefill - - slc_fa_fusion = torch.ops.custom.npu_sparse_flash_attention( - query=q_nope, - key=k_nope, - value=k_nope, - sparse_indices=topk_indices, - scale_value=self.scale, - sparse_block_size=1, - block_table=prefill_metadata.block_table, - actual_seq_lengths_query=prefill_metadata.query_lens, - actual_seq_lengths_kv=prefill_metadata.seq_lens, - query_rope=q_pe, - key_rope=k_rope, - layout_query="TND", - layout_kv="PA_BSND", - sparse_mode=3, - ) - slc_fa_fusion = slc_fa_fusion.transpose(0, 1) - - # input shape [N//attn_tp_size, T(bs*q_len), D] - # output shape [T(bs*q_len), N//attn_tp_size, D] - attn_output = torch.matmul( - slc_fa_fusion, self.kv_b_proj_w_v).transpose(1, 0).reshape( - -1, self.num_heads * self.v_head_dim) - # o_proj_input[num_decode_tokens:] = attn_output - output[...] = self.o_proj(attn_output, is_force_scatter=True) - return output - - elif attn_metadata.decode is not None: - if envs_ascend.VLLM_ASCEND_ENABLE_MLAPO: - prep_res = self._sfa_decode_preprocess(hidden_states, kv_cache, - attn_metadata, - need_gather_q_kv) - q_nope, q_pe = prep_res.query_states - k_nope, k_rope = prep_res.key_states - topk_indices = prep_res.topk_indices - else: - q_len = 1 - hidden_states_decode = hidden_states - - decode_kq = self.q_a_proj(hidden_states_decode) # q down - decode_q_c = self.q_a_layernorm(decode_kq) # q down layernorm - decode_kv_no_split = self.kv_a_proj_with_mqa( - hidden_states_decode) # c_kv - # self.actual_seq_length = torch.arange(1,self.graph_batch_size+1).to(torch.int32).npu() - - # decode_q_c = q_c[:num_decode_tokens] - decode_slot_mapping = attn_metadata.slot_mapping - - decode_q = self.q_b_proj(decode_q_c) - bsz, _ = decode_q.shape - decode_q = decode_q.view(bsz, self.num_heads, 1, - self.qk_head_dim) # [16, 16, 1, 192] - decode_q_nope, decode_q_pe = torch.split( - decode_q, [self.qk_nope_head_dim, self.qk_rope_head_dim], - dim=-1) # [..., 128/64] - decode_q_nope = decode_q_nope.view( - -1, self.num_heads, self.qk_nope_head_dim).transpose(0, 1) - decode_q_nope = (torch.matmul( - decode_q_nope, self.kv_b_proj_w_k).transpose(1, 0).view( - bsz, q_len, self.num_heads, self.kv_lora_rank)) - - # stream2 kv - key_cache = kv_cache[0] - value_cache = kv_cache[1] - cos = attn_metadata.decode.cos # [16, 1, 1, 64] - sin = attn_metadata.decode.sin - cos_q, sin_q = cos, sin - cos = cos.view(-1, 1, 1, self.qk_rope_head_dim) - sin = sin.view(-1, 1, 1, self.qk_rope_head_dim) - - decode_kv_no_split = decode_kv_no_split.unsqueeze(1).unsqueeze( - 1) - decode_k_rope, decode_k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache( - decode_kv_no_split, - self.kv_a_layernorm.weight, - cos, - sin, - 
decode_slot_mapping.to(torch.int64), - value_cache, - key_cache, - c_kv_scale=None, - epsilon=self.kv_a_layernorm.variance_epsilon, - cache_mode='PA') # adapter NZ - # nz_block_size = 16 - # KVCACHE_NZ_DIM = 16 - # decode_k_nope = decode_k_nope.view(block_num, 1, self.kv_lora_rank // nz_block_size, block_size, nz_block_size) - # decode_k_rope = decode_k_rope.view(block_num, 1, self.qk_rope_head_dim // KVCACHE_NZ_DIM, block_size, KVCACHE_NZ_DIM) - decode_q_pe = torch_npu.npu_interleave_rope( - decode_q_pe, cos, sin) # BNSD - - decode_q_nope = decode_q_nope.view(bsz, self.num_heads, - self.kv_lora_rank) - decode_q_pe = decode_q_pe.view(bsz, self.num_heads, -1) - - topk_indices = self.indexer_select(hidden_states_decode, - decode_q_c, - attn_metadata=attn_metadata, - kv_cache=kv_cache, - is_prefill=False) - - query_states = (decode_q_nope, decode_q_pe) - key_states = (decode_k_nope, decode_k_rope) - q_nope, q_pe = query_states - k_nope, k_rope = key_states - - decode_metadata = attn_metadata.decode - slc_fa_fusion = torch.ops.custom.npu_sparse_flash_attention( - query=q_nope, - key=k_nope, - value=k_nope, - sparse_indices=topk_indices, - scale_value=self.scale, - sparse_block_size=1, - block_table=attn_metadata.decode.block_table, - actual_seq_lengths_query=decode_metadata.actual_seq_lengths_q, - actual_seq_lengths_kv=decode_metadata.seq_lens, - query_rope=q_pe, - key_rope=k_rope, - layout_query="TND", - layout_kv="PA_BSND", - sparse_mode=3, - ) - slc_fa_fusion = slc_fa_fusion.squeeze(1) - slc_fa_fusion = slc_fa_fusion.transpose(0, 1) - - # input shape [N//attn_tp_size, T(bs*q_len), D] - # output shape [T(bs*q_len), N//attn_tp_size, D] - attn_output = torch.matmul( - slc_fa_fusion, self.kv_b_proj_w_v).transpose(1, 0).reshape( - -1, self.num_heads * self.v_head_dim) - output[...] = self.o_proj(attn_output) - return output - - def mla_epilog(self, - attn_output: torch.Tensor = None, - absorb: bool = False): - # TODO: - attn_output = self.o_proj(attn_output) - return attn_output - - def indexer_select( - self, - x: torch.Tensor, - qr: torch.Tensor, - kv_cache: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], - attn_metadata: M, - is_prefill: bool = True, - ): - if attn_metadata.prefill is not None: - cos = attn_metadata.prefill.cos - sin = attn_metadata.prefill.sin - elif attn_metadata.decode is not None: - cos = attn_metadata.decode.cos - sin = attn_metadata.decode.sin - - cos_q, sin_q = cos, sin - cos = cos.view(-1, 1, 1, self.qk_rope_head_dim) - sin = sin.view(-1, 1, 1, self.qk_rope_head_dim) - - # q process in new stream - q = self.wq_b(qr) # [b,s,1536] @ [1536,64*128] = [b,s,64*128] - q = q.view(-1, self.n_heads, self.head_dim) # [b,s,64,128] - q_pe, q_nope = torch.split( - q, [self.qk_rope_head_dim, self.head_dim - self.qk_rope_head_dim], - dim=-1) # [b,s,64,64+64] - - q_pe = q_pe.unsqueeze(2) - q_pe = torch_npu.npu_interleave_rope(q_pe, cos_q, sin_q) - q_pe = q_pe.squeeze(2) - q = torch.cat([q_pe, q_nope], dim=-1) # [b*s,64,128] - - k_proj = self.wk(x) # [b,s,7168] @ [7168,128] = [b,s,128] - if self.enable_shared_expert_dp and is_prefill and self.debug_layer_idx > self.first_k_dense_replace and self.debug_layer_idx < self.layers: - k_proj = get_tp_group().all_gather( - k_proj, 0)[attn_metadata.num_decode_tokens:attn_metadata. 
- num_actual_tokens] - k = self.k_norm(k_proj).unsqueeze(1) - k_pe, k_nope = torch.split( - k, [self.qk_rope_head_dim, self.head_dim - self.qk_rope_head_dim], - dim=-1) # [b,s,64+64] - - k_pe = k_pe.unsqueeze(2) - k_pe = torch_npu.npu_interleave_rope(k_pe, cos, sin) - k_pe = k_pe.squeeze(2) - - k = torch.cat([k_pe, k_nope], dim=-1) # [b*s,128] - - if kv_cache is not None: - torch_npu.npu_scatter_nd_update_(kv_cache[2].view(-1, k.shape[-1]), - attn_metadata.slot_mapping.view( - -1, 1), - k.view(-1, - k.shape[-1])) # b, s, n, d - - weights = self.weights_proj(x) - if self.enable_shared_expert_dp and is_prefill and self.debug_layer_idx > self.first_k_dense_replace and self.debug_layer_idx < self.layers: - weights = get_tp_group().all_gather( - weights, 0)[attn_metadata.num_decode_tokens:attn_metadata. - num_actual_tokens] - actual_seq_lengths_query = None - actual_seq_lengths_key = None - block_table = None - if attn_metadata.prefill is not None: - actual_seq_lengths_query = attn_metadata.prefill.query_lens - actual_seq_lengths_key = attn_metadata.prefill.seq_lens - - block_table = attn_metadata.prefill.block_table - elif attn_metadata.decode is not None: - actual_seq_lengths_query = attn_metadata.decode.actual_seq_lengths_q - actual_seq_lengths_key = attn_metadata.decode.seq_lens.to( - torch.int32) - - block_table = attn_metadata.decode.block_table - - topk_indices = torch.ops.custom.npu_lightning_indexer( - query=q, - key=kv_cache[2], - weights=weights, - actual_seq_lengths_query=actual_seq_lengths_query, - actual_seq_lengths_key=actual_seq_lengths_key, - block_table=block_table, - layout_query="TND", - layout_key="PA_BSND", - sparse_count=2048, - sparse_mode=3) - return topk_indices - - -def round_up(val: int, align: int) -> int: - if align == 0: - return 0 - return -(val // -align) * align - - -def trans_rope_weight(weight, rope_dim): - weight_1 = weight[..., -rope_dim::2, :].contiguous() - weight_2 = weight[..., -rope_dim + 1::2, :].contiguous() - weight[..., -rope_dim:, :] = torch.cat([weight_1, weight_2], dim=-2) - - return weight.contiguous() - - -def transdata(nd_mat, block_size: tuple = (16, 16)): - r = round_up(nd_mat.shape[0], block_size[0]) - c = round_up(nd_mat.shape[1], block_size[1]) - r_pad = r - nd_mat.shape[0] - c_pad = c - nd_mat.shape[1] - nd_mat = F.pad(nd_mat, ((0, r_pad, 0, c_pad))) - nz_mat = torch.permute( - torch.reshape( - nd_mat, - (r // block_size[0], block_size[0], c // block_size[1], - block_size[1]), - ), - [2, 0, 1, 3], - ) - nz_mat = torch.reshape( - nz_mat, - (nz_mat.shape[0], nz_mat.shape[1] * nz_mat.shape[2], nz_mat.shape[3])) - return nz_mat diff --git a/vllm_ascend/torchair/torchair_worker.py b/vllm_ascend/torchair/torchair_worker.py deleted file mode 100644 index dbee8003..00000000 --- a/vllm_ascend/torchair/torchair_worker.py +++ /dev/null @@ -1,63 +0,0 @@ -# -# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. -# Copyright 2023 The vLLM team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
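
For context, the transdata helper removed a little above rearranges an ND weight matrix into the blocked NZ layout expected by the fused NPU kernels. A self-contained re-derivation of the same pad-and-permute idea (shapes are arbitrary and the padding bookkeeping is simplified, so this is illustrative rather than the exact removed helper):

import torch
import torch.nn.functional as F

def nd_to_nz(nd_mat: torch.Tensor, block_size=(16, 16)) -> torch.Tensor:
    # Pad both dims up to a multiple of the block size, then regroup into
    # (col_blocks, padded_rows, block_cols), mirroring the reshape/permute
    # steps of the removed transdata() helper.
    r = -(-nd_mat.shape[0] // block_size[0]) * block_size[0]
    c = -(-nd_mat.shape[1] // block_size[1]) * block_size[1]
    nd_mat = F.pad(nd_mat, (0, c - nd_mat.shape[1], 0, r - nd_mat.shape[0]))
    nz = nd_mat.reshape(r // block_size[0], block_size[0],
                        c // block_size[1], block_size[1]).permute(2, 0, 1, 3)
    return nz.reshape(nz.shape[0], nz.shape[1] * nz.shape[2], nz.shape[3])

print(nd_to_nz(torch.randn(30, 40)).shape)  # torch.Size([3, 32, 16])
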
-import torch -from vllm.logger import logger - -import vllm_ascend.envs as envs_ascend -from vllm_ascend.ascend_config import get_ascend_config -from vllm_ascend.torchair.torchair_model_runner import NPUTorchairModelRunner -from vllm_ascend.torchair.utils import (check_kv_cache_bytes_cache_exist, - delete_torchair_cache_file, - read_kv_cache_bytes_from_file) -from vllm_ascend.worker.worker_v1 import NPUWorker - - -class NPUTorchairWorker(NPUWorker): - """Torchair worker bases on NPUWorker. Only torchair specified code should be added in this class.""" - - def determine_available_memory(self) -> int: - """Override determine_available_memory to use cached torchair kv_cache_bytes.""" - - available_kv_cache_memory = super().determine_available_memory() - ascend_config = get_ascend_config() - if ascend_config.enable_shared_expert_dp: - return available_kv_cache_memory - if ascend_config.torchair_graph_config.use_cached_kv_cache_bytes: - if check_kv_cache_bytes_cache_exist(): - old_kv_cache_bytes = read_kv_cache_bytes_from_file( - torch.distributed.get_rank()) - if 0 < old_kv_cache_bytes <= available_kv_cache_memory: - logger.info( - f"Use cached torchair kv_cache_bytes: {old_kv_cache_bytes}" - ) - self.model_runner.new_kv_cache_bytes = old_kv_cache_bytes - return old_kv_cache_bytes - else: - logger.info( - "Cached torchair kv_cache_bytes is too big, invalidate old torchair_cache" - ) - delete_torchair_cache_file() - bytes_floating_tolerance = 1024 * 1024 * envs_ascend.VLLM_ASCEND_KV_CACHE_MEGABYTES_FLOATING_TOLERANCE - available_kv_cache_memory -= bytes_floating_tolerance - logger.info(f"Use new kv_cache_bytes: {available_kv_cache_memory}") - self.model_runner.new_kv_cache_bytes = available_kv_cache_memory - return available_kv_cache_memory - - def init_device(self): - """Override init_device to init torchair model runner""" - device = self._init_device() - # Init ModelRunner here, so that we have access to self.device. - self.model_runner = NPUTorchairModelRunner(self.vllm_config, device) diff --git a/vllm_ascend/torchair/utils.py b/vllm_ascend/torchair/utils.py deleted file mode 100644 index 19367038..00000000 --- a/vllm_ascend/torchair/utils.py +++ /dev/null @@ -1,275 +0,0 @@ -import fcntl -import os -import shutil -from contextlib import contextmanager, nullcontext -from dataclasses import dataclass - -import torch -import torch_npu -from torchair.scope import super_kernel as _super_kernel - -try: - # Recent release of torchair has moved these ops to `.scope`. - from torchair.scope import npu_stream_switch as _npu_stream_switch - from torchair.scope import npu_wait_tensor as _npu_wait_tensor -except ImportError: - from torchair.ops import NpuStreamSwitch as _npu_stream_switch - from torchair.ops import npu_wait_tensor as _npu_wait_tensor - -import vllm_ascend.envs as envs_ascend -from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_enable_nz - -KV_CACHE_BYTES_CACHE_PATH_NAME = ".kv_cache_bytes" -KV_CACHE_BYTES_CACHE_FILE_NAME = "kv_cache_bytes" -TORCHAIR_CACHE_PATH_NAME = ".torchair_cache" -TORCHAIR_CACHE_DIR = os.path.join( - os.getenv('TORCHAIR_CACHE_HOME', os.getcwd()), TORCHAIR_CACHE_PATH_NAME) - - -@dataclass -class TorchairCommonAttentionMetadata: - """ - Per-batch attention metadata, shared across layers and backends. - AttentionMetadataBuilder instances use it to construct per-layer metadata. - - For many of the tensors we keep both GPU and CPU versions. 
- """ - - num_reqs: int - """Number of requests""" - - num_actual_tokens: int - """Total number of tokens in batch""" - - decode_token_per_req: int - - actual_seq_lengths_q: list[int] - - attn_mask: torch.Tensor = None - - spec_attn_mask: torch.Tensor = None - - graph_pad_size: int = -1 - - -@contextmanager -def _file_lock(file_descriptor, lock_type): - fcntl.flock(file_descriptor, lock_type) - try: - yield - finally: - fcntl.flock(file_descriptor, fcntl.LOCK_UN) - - -def _get_torchair_current_work_dir(file_name=None): - if file_name is None: - return TORCHAIR_CACHE_DIR - return os.path.join(TORCHAIR_CACHE_DIR, file_name) - - -def check_torchair_cache_exist(): - res = False - torch_air_abs_path = _get_torchair_current_work_dir() - if os.path.exists(torch_air_abs_path): - file_list = os.listdir(torch_air_abs_path) - if len(file_list) != 0: - res = True - return res - - -def check_kv_cache_bytes_cache_exist(): - res = False - kv_cache_bytes_cache_abs_path = _get_torchair_current_work_dir( - KV_CACHE_BYTES_CACHE_PATH_NAME) - if os.path.exists(kv_cache_bytes_cache_abs_path): - file_list = os.listdir(kv_cache_bytes_cache_abs_path) - if len(file_list) != 0: - res = True - return res - - -def read_kv_cache_bytes_from_file(rank) -> int: - kv_cache_bytes = -1 - kv_cache_bytes_cache_abs_path = _get_torchair_current_work_dir( - KV_CACHE_BYTES_CACHE_PATH_NAME) - kv_cache_bytes_file = os.path.join( - kv_cache_bytes_cache_abs_path, - f"{rank}_{KV_CACHE_BYTES_CACHE_FILE_NAME}") - with open(kv_cache_bytes_file, "r", encoding="utf-8") as f: - with _file_lock(f, fcntl.LOCK_SH): - kv_cache_bytes = int(f.readline()) - return kv_cache_bytes - - -def write_kv_cache_bytes_to_file(rank, kv_cache_bytes): - kv_cache_bytes_cache_abs_path = _get_torchair_current_work_dir( - KV_CACHE_BYTES_CACHE_PATH_NAME) - os.makedirs(kv_cache_bytes_cache_abs_path, exist_ok=True) - kv_cache_bytes_file = os.path.join( - kv_cache_bytes_cache_abs_path, - f"{rank}_{KV_CACHE_BYTES_CACHE_FILE_NAME}") - with open(kv_cache_bytes_file, "w", encoding="utf-8") as f: - with _file_lock(f, fcntl.LOCK_EX): - f.write(f"{kv_cache_bytes}") - - -def delete_torchair_cache_file(): - torch_air_abs_path = _get_torchair_current_work_dir() - try: - shutil.rmtree(torch_air_abs_path) - except FileNotFoundError: - pass - - -def npu_stream_switch(tag: str, priority: int, *, enabled: bool = True): - return _npu_stream_switch(tag, priority) if enabled else nullcontext() - - -def npu_wait_tensor(self: torch.Tensor, - dependency: torch.Tensor, - *, - enabled: bool = True): - return _npu_wait_tensor(self, dependency) if enabled else self - - -def converting_weight_acl_format(model, format): - # currently, there are some operations which do not support ACL_FORMAT_FRACTAL_NZ - # in eager mode but support it in torchair graph mode. since ACL_FORMAT_FRACTAL_NZ - # is much more preferred than ACL_FORMAT_FRACTAL_ND on 300I Duo, we add this - # conversion when using torchair graph mode on 300I Duo platform. - # TODO: we will remove this conversion if npu_quant_grouped_matmul_dequant - # accepts weight format of ACL_FORMAT_FRACTAL_NZ in eager mode. 
- from vllm.model_executor.layers.fused_moe.layer import FusedMoE - - for module in model.modules(): - if isinstance(module, FusedMoE): - if torch_npu.get_npu_format(module.w13_weight.data) == format: - return - if format == ACL_FORMAT_FRACTAL_NZ \ - and not is_enable_nz(): - return - module.w13_weight.data = torch_npu.npu_format_cast( - module.w13_weight.data, format) - module.w2_weight.data = torch_npu.npu_format_cast( - module.w2_weight.data, format) - - -def register_torchair_model(): - from vllm import ModelRegistry - - ModelRegistry.register_model( - "DeepSeekMTPModel", - "vllm_ascend.torchair.models.torchair_deepseek_mtp:TorchairDeepSeekMTP" - ) - - ModelRegistry.register_model( - "DeepseekV2ForCausalLM", - "vllm_ascend.torchair.models.torchair_deepseek_v2:TorchairDeepseekV2ForCausalLM" - ) - - ModelRegistry.register_model( - "DeepseekV3ForCausalLM", - "vllm_ascend.torchair.models.torchair_deepseek_v3:TorchairDeepseekV3ForCausalLM" - ) - - ModelRegistry.register_model( - "DeepseekV32ForCausalLM", - "vllm_ascend.torchair.models.torchair_deepseek_v3:TorchairDeepseekV3ForCausalLM" - ) - - ModelRegistry.register_model( - "Qwen2ForCausalLM", - "vllm_ascend.torchair.models.qwen2:CustomQwen2ForCausalLM") - - ModelRegistry.register_model( - "Qwen3MoeForCausalLM", - "vllm_ascend.torchair.models.qwen3_moe:CustomQwen3MoeForCausalLM") - - ModelRegistry.register_model( - "PanguProMoEForCausalLM", - "vllm_ascend.torchair.models.torchair_pangu_moe:PanguProMoEForCausalLM" - ) - - -def torchair_quant_method_register(): - from vllm_ascend.quantization.utils import ASCEND_QUANTIZATION_METHOD_MAP - from vllm_ascend.torchair.quantization.torchair_w4a8_dynamic import ( - TorchairAscendW4A8DynamicFusedMoEMethod, - TorchairAscendW4A8DynamicLinearMethod) - from vllm_ascend.torchair.quantization.torchair_w8a8_dynamic import ( - TorchairAscendW8A8DynamicFusedMoEMethod, - TorchairAscendW8A8DynamicLinearMethod) - - ASCEND_QUANTIZATION_METHOD_MAP["W8A8_DYNAMIC"][ - "linear"] = TorchairAscendW8A8DynamicLinearMethod - ASCEND_QUANTIZATION_METHOD_MAP["W8A8_DYNAMIC"][ - "moe"] = TorchairAscendW8A8DynamicFusedMoEMethod - ASCEND_QUANTIZATION_METHOD_MAP["W4A8_DYNAMIC"][ - "linear"] = TorchairAscendW4A8DynamicLinearMethod - ASCEND_QUANTIZATION_METHOD_MAP["W4A8_DYNAMIC"][ - "moe"] = TorchairAscendW4A8DynamicFusedMoEMethod - - -def torchair_ops_patch(): - from vllm_ascend.ops.activation import AscendSiluAndMul - from vllm_ascend.ops.layernorm import AscendQuantRMSNorm, AscendRMSNorm - from vllm_ascend.ops.rotary_embedding import ( - AscendDeepseekScalingRotaryEmbedding, AscendRotaryEmbedding) - from vllm_ascend.ops.vocab_parallel_embedding import \ - AscendVocabParallelEmbedding - from vllm_ascend.torchair.ops import (torchair_activation, - torchair_layernorm) - from vllm_ascend.torchair.ops.torchair_rotary_embedding import ( - deepseek_rope_init_func, native_rope_deepseek_forward, - qwen_rope_init_func, rope_forward) - from vllm_ascend.torchair.ops.torchair_vocab_parallel_embedding import \ - vocab_embedding_forward - - AscendRotaryEmbedding.__init__ = qwen_rope_init_func # type: ignore[method-assign] - AscendRotaryEmbedding.forward_oot = rope_forward # type: ignore[method-assign] - - AscendDeepseekScalingRotaryEmbedding.__init__ = deepseek_rope_init_func # type: ignore[method-assign] - AscendDeepseekScalingRotaryEmbedding.forward = native_rope_deepseek_forward # type: ignore[method-assign] - - AscendRMSNorm.__init__ = torchair_layernorm.torchair_rmsnorm_init_ # type: ignore[method-assign] - AscendRMSNorm.forward_oot = 
torchair_layernorm.torchair_rmsnorm_forward_oot # type: ignore[method-assign] - - AscendQuantRMSNorm.__init__ = torchair_layernorm.torchair_rmsnorm_init_ # type: ignore[method-assign] - AscendQuantRMSNorm.forward_oot = torchair_layernorm.torchair_rmsnorm_forward_oot # type: ignore[method-assign] - - AscendSiluAndMul.forward_oot = torchair_activation.torchair_silu_and_mul_forward_oot # type: ignore[method-assign] - AscendVocabParallelEmbedding.forward = vocab_embedding_forward # type: ignore[method-assign] - - -def super_kernel(prefix: str, option: str, enabled: bool = True): - return _super_kernel(prefix, option) if enabled else nullcontext() - - -# TODO(ttanzhiqiang): rm_router_logits -# dp>1 will trigger -# In theory, this solution is only applicable to AllGather and AllGatherEP, because in the dp scenario, the previous operation was gate + two communications, and now it is changed to one communication + gate operation, which can save some communication time. In theory, all moe AllGather and AllGatherEP solutions can follow this logic, but now other moe models (qwen3-235b) dp solutions are not adjusted, so use the switch to control it to prevent code errors. -def get_rm_router_logits_state(ep_size: int, dp_size: int, - is_deepseek_v3_r1: bool): - # the fusion operator torch_npu.npu_grouped_matmul_finalize_routing called by allgather ep - # only supports deepseek v3/r1 - if dp_size > 1: - if (envs_ascend.VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP and ep_size > 1 - and is_deepseek_v3_r1): - return True - elif ep_size == 1 and is_deepseek_v3_r1: - return True - return False - - -# TODO(ttanzhiqiang): all_reduce merge -# When all_reduce_merge is in progress, shared_experts does not do all_reduce in mlp, but waits until shared_experts+router_experts are completed before doing all_reduce -# Currently, all_reduce_merge is enabled by default in the AllGather, AllGatherEP and NaiveMulticast scenarios of the deepseek model. -def get_all_reduce_merge_state(ep_size: int, is_deepseek_v3_r1: bool): - # the fusion operator torch_npu.npu_grouped_matmul_finalize_routing called by allgather ep - # only supports deepseek v3/r1 - if (envs_ascend.VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP and ep_size > 1 - and is_deepseek_v3_r1): - return True - elif ep_size == 1 and is_deepseek_v3_r1: - return True - return False diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index f32940a2..8d57417c 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -143,7 +143,6 @@ from vllm_ascend.spec_decode import get_spec_decode_method from vllm_ascend.spec_decode.eagle_proposer import EagleProposer from vllm_ascend.spec_decode.interface import SpecDcodeType from vllm_ascend.spec_decode.mtp_proposer import MtpProposer -from vllm_ascend.torchair.torchair_mtp_proposer import TorchairMtpProposer from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ, AscendDeviceType, ProfileExecuteDuration, enable_sp, get_ascend_device_type, is_enable_nz, @@ -638,7 +637,6 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin): # Set up speculative decoding. 
self.spec_attn_mask = None self.drafter: Optional[Union[NgramProposer, EagleProposer, MtpProposer, - TorchairMtpProposer, SuffixDecodingProposer]] = None self.actual_seq_lengths_q: list[int] = [] self.decode_token_per_req = 1 @@ -2917,8 +2915,7 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin): return attn_metadata - def _generate_dummy_run_hidden_states(self, with_prefill, - is_torchair_compile, input_ids, + def _generate_dummy_run_hidden_states(self, with_prefill, input_ids, positions, attn_metadata, num_tokens, intermediate_tensors, inputs_embeds): hidden_states = self.model(input_ids=input_ids, @@ -2960,7 +2957,6 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin): self, num_tokens: int, with_prefill: bool = False, - is_torchair_compile: bool = False, aclgraph_runtime_mode: Optional[CUDAGraphMode] = None, force_attention: bool = False, uniform_decode: bool = False, @@ -3136,9 +3132,8 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin): model_instance=self.model, weight_prefetch_method=self.weight_prefetch_method): hidden_states = self._generate_dummy_run_hidden_states( - with_prefill, is_torchair_compile, input_ids, positions, - attn_metadata, num_tokens_padded, intermediate_tensors, - inputs_embeds) + with_prefill, input_ids, positions, attn_metadata, + num_tokens_padded, intermediate_tensors, inputs_embeds) dummy_compute_logits(hidden_states) if self.drafter: @@ -4262,9 +4257,6 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin): return list(model.pooler.get_supported_tasks()) - def _build_drafter_prepare_inputs_torchair_param(self): - return False - def _update_tokens_for_pcp(self, tokens): num_reqs = self.input_batch.num_reqs self.num_pcp_pads = self.num_pcp_pads[:num_reqs]
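
For reference, the removed NPUTorchairWorker.determine_available_memory above boils down to a small decision about whether a previously cached kv_cache_bytes value can still be reused. A distilled sketch of just that decision (file locking, on-disk cache invalidation, logging, and the env-driven tolerance are omitted; the fixed 1 MiB tolerance is a placeholder, not the real configuration):

from typing import Optional

def choose_kv_cache_bytes(measured: int,
                          cached: Optional[int],
                          tolerance_bytes: int = 1024 * 1024) -> int:
    # Reuse the cached value only if it is positive and still fits inside
    # the freshly measured budget; otherwise fall back to the measurement
    # minus a small floating tolerance.
    if cached is not None and 0 < cached <= measured:
        return cached
    return measured - tolerance_bytes

assert choose_kv_cache_bytes(10 * 1024**2, 8 * 1024**2) == 8 * 1024**2
assert choose_kv_cache_bytes(10 * 1024**2, None) == 9 * 1024**2
assert choose_kv_cache_bytes(10 * 1024**2, 12 * 1024**2) == 9 * 1024**2
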