2025-02-05 10:53:12 +08:00
|
|
|
#
|
|
|
|
|
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
|
|
|
|
#
|
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
|
#
|
|
|
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
|
#
|
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
|
# limitations under the License.
|
2025-04-17 14:59:56 +08:00
|
|
|
# This file is a part of the vllm-ascend project.
|
2025-02-05 10:53:12 +08:00
|
|
|
#
|
|
|
|
|
|
2025-06-06 21:54:02 +08:00
|
|
|
import gc
|
2025-02-05 10:53:12 +08:00
|
|
|
import os
|
2025-06-09 14:08:18 +08:00
|
|
|
from datetime import timedelta
|
2025-02-21 17:07:37 +08:00
|
|
|
from typing import TYPE_CHECKING, Optional, Tuple
|
2025-02-05 10:53:12 +08:00
|
|
|
|
|
|
|
|
import torch
|
2025-03-20 19:34:44 +08:00
|
|
|
import vllm.envs as envs
|
2025-06-09 14:08:18 +08:00
|
|
|
from torch.distributed import ProcessGroup
|
|
|
|
|
from torch.distributed.distributed_c10d import PrefixStore
|
2025-04-15 10:18:05 +08:00
|
|
|
from vllm.logger import logger
|
2025-04-18 08:56:05 +08:00
|
|
|
from vllm.platforms import Platform, PlatformEnum
|
2025-04-03 14:52:34 +08:00
|
|
|
|
2025-07-03 22:21:42 +08:00
|
|
|
from vllm_ascend.ascend_config import (check_ascend_config, get_ascend_config,
|
|
|
|
|
init_ascend_config)
|
[Platform] Add initial experimental support for Atlas 300I series (#1333)
### What this PR does / why we need it?
Add initial experimental support for Ascend 310P. This patch squashes
the PRs below into one to ease validation:
- https://github.com/vllm-project/vllm-ascend/pull/914
- https://github.com/vllm-project/vllm-ascend/pull/1318
- https://github.com/vllm-project/vllm-ascend/pull/1327
### Does this PR introduce _any_ user-facing change?
Users can run vLLM on the Atlas 300I DUO series.
### How was this patch tested?
CI passed with:
- E2E image build for 310P
- CI test on A2 with e2e test and longterm test
- Unit tests are missing because they require a real 310P image;
they will be added in a separate PR later.
- Manually e2e test:
- Qwen2.5-7b-instruct, Qwen2.5-0.5b, Qwen3-0.6B, Qwen3-4B, Qwen3-8B:
https://github.com/vllm-project/vllm-ascend/pull/914#issuecomment-2942989322
- Pangu MGoE 72B
The patch has been tested locally on Ascend 310P hardware to ensure that
the changes do not break existing functionality and that the new
features work as intended.
#### ENV information
CANN, NNAL version: 8.1.RC1
> [!IMPORTANT]
> PTA 2.5.1 version >= torch_npu-2.5.1.post1.dev20250528 to support NZ
format and calling NNAL operators on 310P
#### Code example
##### Build vllm-ascend from source code
```shell
# download source code as vllm-ascend
cd vllm-ascend
export SOC_VERSION=Ascend310P3
pip install -v -e .
cd ..
```
##### Run offline inference
```python
from vllm import LLM, SamplingParams
prompts = ["水的沸点是100摄氏度吗?请回答是或者否。", "若腋下体温为38摄氏度,请问这人是否发烧?请回答是或者否。",
"水的沸点是100摄氏度吗?请回答是或者否。", "若腋下体温为38摄氏度,请问这人是否发烧?请回答是或者否。"]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.0, top_p=0.95, max_tokens=10)
# Create an LLM.
llm = LLM(
model="Qwen/Qwen2.5-7B-Instruct",
max_model_len=4096,
max_num_seqs=4,
dtype="float16", # IMPORTANT cause some ATB ops cannot support bf16 on 310P
disable_custom_all_reduce=True,
trust_remote_code=True,
tensor_parallel_size=2,
compilation_config={"custom_ops":['none', "+rms_norm", "+rotary_embedding"]},
)
# Generate texts from the prompts.
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
```
---------
Signed-off-by: Vincent Yuan <farawayboat@gmail.com>
Signed-off-by: Yikun Jiang <yikunkero@gmail.com>
Signed-off-by: angazenn <zengyanjia@huawei.com>
Co-authored-by: Vincent Yuan <farawayboat@gmail.com>
Co-authored-by: angazenn <zengyanjia@huawei.com>
Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com>
Co-authored-by: leo-pony <nengjunma@outlook.com>
Co-authored-by: shen-shanshan <467638484@qq.com>
2025-06-21 09:00:16 +08:00
|
|
|
from vllm_ascend.utils import (ASCEND_QUATIZATION_METHOD, is_310p,
|
|
|
|
|
update_aclgraph_sizes)
|
2025-05-12 20:26:22 +08:00
|
|
|
|
2025-02-21 17:07:37 +08:00
|
|
|
if TYPE_CHECKING:
|
2025-03-28 19:34:23 +08:00
|
|
|
from vllm.config import ModelConfig, VllmConfig
|
2025-02-21 17:07:37 +08:00
|
|
|
from vllm.utils import FlexibleArgumentParser
|
|
|
|
|
else:
|
2025-03-28 19:34:23 +08:00
|
|
|
ModelConfig = None
|
|
|
|
|
VllmConfig = None
|
2025-02-21 17:07:37 +08:00
|
|
|
FlexibleArgumentParser = None
|
|
|
|
|
|
2025-02-05 10:53:12 +08:00
|
|
|
|
|
|
|
|
class NPUPlatform(Platform):
    """vLLM platform implementation for Huawei Ascend NPUs.

    Registered as an out-of-tree platform (``PlatformEnum.OOT``). Device
    hooks are forwarded to ``torch.npu``, and the Ascend-specific worker,
    attention backend, communicator and compilation settings are selected
    in :meth:`check_and_update_config`.
    """

    _enum = PlatformEnum.OOT
    device_name: str = "npu"
    device_type: str = "npu"
    simple_compile_backend: str = "eager"  # Disable torch.compile()
    ray_device_key: str = "NPU"
    # Environment variable used to restrict which NPUs a process can see.
    device_control_env_var: str = "ASCEND_RT_VISIBLE_DEVICES"
    dispatch_key: str = "PrivateUse1"

    supported_quantization: list[str] = [ASCEND_QUATIZATION_METHOD]

    def is_sleep_mode_available(self) -> bool:
        """Sleep mode (weight offload / KV-cache discard) is always available."""
        return True

    @classmethod
    def pre_register_and_update(cls,
                                parser: Optional[FlexibleArgumentParser] = None
                                ) -> None:
        """Prepare vLLM for the Ascend platform before CLI args are parsed.

        Applies the global vllm-ascend patches. When ``parser`` is given, the
        Ascend quantization method is also added to the ``--quantization``
        choices.

        Args:
            parser: The vLLM CLI argument parser, or ``None`` when not running
                from the command line.
        """
        # Adapt the global patch here.
        from vllm_ascend.utils import adapt_patch
        adapt_patch(is_global_patch=True)

        # For online serving, "ascend" quantization method is not a choice natively,
        # so we need to add "ascend" quantization method to quantization methods list
        # and the user can enable quantization using "vllm serve --quantization ascend".
        if parser is not None:
            quant_action = parser._option_string_actions.get('--quantization')
            # ``choices`` may be None (any value accepted); only extend an
            # explicit, non-empty choice set.
            choices = getattr(quant_action, 'choices',
                              None) if quant_action else None
            if choices and ASCEND_QUATIZATION_METHOD not in choices:
                if isinstance(choices, list):
                    choices.append(ASCEND_QUATIZATION_METHOD)
                else:
                    # argparse accepts any container (tuple, set, ...);
                    # rebuild as a list rather than failing on ``append``.
                    quant_action.choices = [
                        *choices, ASCEND_QUATIZATION_METHOD
                    ]

        # Imported only for its import-time side effects (the name itself is
        # unused, hence F401) — presumably registration of the Ascend
        # quantization config; verify in quant_config.
        from vllm_ascend.quantization.quant_config import \
            AscendQuantConfig  # noqa: F401

    @classmethod
    def get_device_capability(cls, device_id: int = 0):
        """No device capability is reported for NPUs; always ``None``."""
        return None

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the name of NPU ``device_id`` as reported by ``torch.npu``."""
        return torch.npu.get_device_name(device_id)

    @classmethod
    def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
        """Async output processing is always supported, regardless of mode."""
        return True

    @classmethod
    def inference_mode(cls):
        """Return a ``torch.inference_mode()`` context manager."""
        return torch.inference_mode()

    @classmethod
    def set_device(cls, device: torch.device):
        """Make ``device`` the current NPU for this process."""
        torch.npu.set_device(device)

    @classmethod
    def empty_cache(cls):
        """Release cached, unused blocks held by the NPU caching allocator."""
        torch.npu.empty_cache()

    @classmethod
    def synchronize(cls):
        """Block until all queued work on the current NPU has completed."""
        torch.npu.synchronize()

    @classmethod
    def mem_get_info(cls) -> Tuple[int, int]:
        """Return ``torch.npu.mem_get_info()`` — (free, total) bytes."""
        return torch.npu.mem_get_info()

    @classmethod
    def clear_npu_memory(cls):
        """Run Python GC, empty the NPU cache and reset peak-memory stats."""
        gc.collect()
        torch.npu.empty_cache()
        torch.npu.reset_peak_memory_stats()

    @classmethod
    def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
        """Validate ``vllm_config`` and adapt it in place for Ascend NPUs.

        Steps, in order:

        1. Initialise the Ascend config from ``additional_config``.
        2. Apply an optional ``kv_cache_dtype`` override.
        3. Derive the expert-tensor-parallel and expert-parallel sizes.
        4. Pick the compilation level: eager, torchair, or piecewise ACL
           graph.
        5. Resolve the worker class when it is ``"auto"``.
        6. Normalise the KV-cache block size.
        7. On V1, enable custom ops (except on 310P) and optionally switch to
           the Ascend scheduler.

        Args:
            vllm_config: The engine configuration; mutated in place.
        """
        # initialize ascend config from vllm additional_config
        ascend_config = init_ascend_config(vllm_config)

        from vllm.config import CompilationLevel  # noqa: E402
        compilation_config = vllm_config.compilation_config
        model_config = vllm_config.model_config
        parallel_config = vllm_config.parallel_config
        cache_config = vllm_config.cache_config

        # ``additional_config`` may be None when the user supplied none
        # (NOTE(review): depends on the vLLM version — confirm); treat it as
        # empty instead of raising AttributeError.
        additional_config = vllm_config.additional_config or {}
        kv_cache_dtype = additional_config.get("kv_cache_dtype", None)
        if kv_cache_dtype is not None:
            cache_config.cache_dtype = kv_cache_dtype

        if parallel_config:
            # Default value for expert tensor parallel size
            parallel_config.expert_tensor_parallel_size = parallel_config.tensor_parallel_size

            # NOTE: When enable_expert_parallel is True, we follow vLLM convention:
            # ep_size = world_size, which means expert_tensor_parallel_size must be 1
            if parallel_config.enable_expert_parallel:
                parallel_config.expert_tensor_parallel_size = 1
            # NOTE: When enable_expert_parallel is False and param `ascend_config.expert_tensor_parallel_size`
            # is configured, use ascend_config
            elif ascend_config.expert_tensor_parallel_size > 0:
                parallel_config.expert_tensor_parallel_size = ascend_config.expert_tensor_parallel_size

            # Calculate expert parallel size based on world size
            parallel_config.expert_parallel_size = (
                parallel_config.world_size_across_dp //
                parallel_config.expert_tensor_parallel_size)

        if model_config is None:
            logger.warning("Model config is missing. This may indicate "
                           "that we are running a test case")
            enforce_eager = False
        else:
            enforce_eager = getattr(model_config, "enforce_eager", False)

        check_ascend_config(vllm_config, enforce_eager)

        if enforce_eager or compilation_config.level == CompilationLevel.NO_COMPILATION:
            logger.info("Compilation disabled, using eager mode by default")
            compilation_config.level = CompilationLevel.NO_COMPILATION
        elif compilation_config.level != CompilationLevel.PIECEWISE:
            logger.warning(
                "NPU does not support %s compilation level. Setting level to NO_COMPILATION",
                compilation_config.level)
            compilation_config.level = CompilationLevel.NO_COMPILATION
        elif ascend_config.torchair_graph_config.enabled:
            logger.info(
                "Torchair compilation enabled on NPU. Setting level to NO_COMPILATION"
            )
            compilation_config.level = CompilationLevel.NO_COMPILATION
        else:
            logger.info(
                "PIECEWISE compilation enabled on NPU. use_inductor not supported - "
                "using only ACL Graph mode")
            compilation_config.use_inductor = False
            # Split the graph at the Ascend attention op. Append it only once
            # so that re-running this method does not duplicate the entry.
            attention_op = "vllm.unified_ascend_attention_with_output"
            if attention_op not in compilation_config.splitting_ops:
                compilation_config.splitting_ops.append(attention_op)
            update_aclgraph_sizes(vllm_config)

        if parallel_config and parallel_config.worker_cls == "auto":
            if envs.VLLM_USE_V1:
                parallel_config.worker_cls = "vllm_ascend.worker.worker_v1.NPUWorker"
            elif vllm_config.speculative_config:
                # NOTE: We set this var to `1` in vllm-ascend to avoid segment
                # fault when using spec decode with V0 engine.
                os.environ["ACL_OP_INIT_MODE"] = "1"
                parallel_config.worker_cls = "vllm.spec_decode.spec_decode_worker.create_spec_worker"
                parallel_config.sd_worker_cls = "vllm_ascend.worker.worker.NPUWorker"
            elif vllm_config.scheduler_config.is_multi_step:
                parallel_config.worker_cls = "vllm_ascend.worker.multi_step_worker.MultiStepWorker"
            else:
                parallel_config.worker_cls = "vllm_ascend.worker.worker.NPUWorker"

        if cache_config:
            if cache_config.block_size is None:
                cache_config.block_size = 128
            # Prefix caching on NPU is only supported with 128-token blocks;
            # override any other user-supplied size.
            if cache_config.enable_prefix_caching and cache_config.block_size != 128:
                logger.warning(
                    "If prefix caching is enabled, block size must be set to 128."
                )
                cache_config.block_size = 128

        if envs.VLLM_USE_V1:
            # Activate custom ops for v1, except on 310P
            if not is_310p():
                compilation_config.custom_ops = ["all"]

            # If ascend_scheduler_config is enabled,
            # extends the original scheduler_config to use AscendScheduler.
            if ascend_config.ascend_scheduler_config.enabled:
                from vllm_ascend.core.schedule_config import \
                    AscendSchedulerConfig
                ascend_scheduler_config = AscendSchedulerConfig.initialize_from_config(
                    vllm_config.scheduler_config,
                    ascend_config.ascend_scheduler_config)
                vllm_config.scheduler_config = ascend_scheduler_config

    @classmethod
    def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
                             kv_cache_dtype, block_size, use_v1, use_mla):
        """Return the dotted path of the attention backend class to use.

        Only ``use_v1``, ``use_mla`` and (on V1) the torchair setting from
        the Ascend config influence the choice; the remaining arguments are
        part of the platform interface and are ignored here.
        """
        if not use_v1:
            # V0 engine: torchair is not consulted, so the Ascend config does
            # not need to be read on this path.
            if use_mla:
                return "vllm_ascend.attention.attention.AscendMLAAttentionBackend"
            return "vllm_ascend.attention.attention.AscendAttentionBackend"
        if use_mla:
            return "vllm_ascend.attention.mla_v1.AscendMLABackend"
        if get_ascend_config().torchair_graph_config.enabled:
            return "vllm_ascend.attention.attention_v1_torchair.AscendAttentionTorchairBackend"
        return "vllm_ascend.attention.attention_v1.AscendAttentionBackend"

    @classmethod
    def get_punica_wrapper(cls) -> str:
        """Return the dotted path of the NPU Punica (LoRA) wrapper class."""
        return "vllm_ascend.lora.punica_wrapper.punica_npu.PunicaWrapperNPU"

    @classmethod
    def get_current_memory_usage(cls,
                                 device: Optional[torch.types.Device] = None
                                 ) -> float:
        """Return the NPU memory currently allocated on ``device``.

        The peak counter is reset first, so ``max_memory_allocated`` reports
        the present allocation rather than a historical peak. This presumably
        mirrors the ``torch.cuda`` semantics of these calls — confirm for
        torch_npu.
        """
        torch.npu.reset_peak_memory_stats(device)
        return torch.npu.max_memory_allocated(device)

    @classmethod
    def get_device_communicator_cls(cls) -> str:
        """Return the dotted path of the NPU device communicator class."""
        return "vllm_ascend.distributed.communicator.NPUCommunicator"

    @classmethod
    def is_pin_memory_available(cls):
        """Pinned host memory is always reported as available."""
        return True

    @classmethod
    def supports_v1(cls, model_config: ModelConfig) -> bool:
        """Returns whether the current platform can support v1 for the supplied
        model configuration.

        Always ``True`` on NPU; ``model_config`` is not inspected.
        """
        return True

    @classmethod
    def get_piecewise_backend_cls(cls) -> str:
        """
        Get piecewise backend class for piecewise graph.
        """
        return "vllm_ascend.compilation.piecewise_backend.NPUPiecewiseBackend"  # noqa

    @classmethod
    def stateless_init_device_torch_dist_pg(
        cls,
        backend: str,
        prefix_store: PrefixStore,
        group_rank: int,
        group_size: int,
        timeout: timedelta,
    ) -> ProcessGroup:
        """Build an HCCL-backed ``ProcessGroup`` directly from ``prefix_store``.

        The group is constructed by hand rather than through
        ``torch.distributed.new_group``. An HCCL backend with the given
        ``timeout`` is registered for the ``npu`` device under
        ``BackendType.CUSTOM``.

        Raises:
            AssertionError: if HCCL is not available.
        """
        from torch.distributed import is_hccl_available
        from torch_npu._C._distributed_c10d import ProcessGroupHCCL

        assert is_hccl_available()

        # TODO(Yizhou): The reason we need to set options while vllm does not
        # seems to be related to the version of PyTorch. In the latest version,
        # there is no need to set options. While in the older version, 2.5.1
        # specifically, we need to set options.
        options = ProcessGroup.Options(backend=backend)
        pg: ProcessGroup = ProcessGroup(
            prefix_store,
            group_rank,
            group_size,
            options,
        )

        backend_options = ProcessGroupHCCL.Options()
        backend_options._timeout = timeout

        backend_class = ProcessGroupHCCL(prefix_store, group_rank, group_size,
                                         backend_options)
        device = torch.device("npu")
        # TODO(Yizhou): Like we mentioned above, _set_default_backend is not
        # implemented in the 2.5.1 version of PyTorch. But we need to set it
        # after the latest version is released.
        # pg._set_default_backend(backend_type)
        backend_class._set_sequence_number_for_group()
        backend_type = ProcessGroup.BackendType.CUSTOM

        pg._register_backend(device, backend_type, backend_class)
        return pg