2025-02-05 10:53:12 +08:00
|
|
|
#
|
|
|
|
|
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
|
|
|
|
#
|
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
|
#
|
|
|
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
|
#
|
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
|
# limitations under the License.
|
2025-04-17 14:59:56 +08:00
|
|
|
# This file is a part of the vllm-ascend project.
|
2025-02-05 10:53:12 +08:00
|
|
|
#
|
|
|
|
|
|
2025-06-06 21:54:02 +08:00
|
|
|
import gc
|
2025-09-29 09:12:49 +08:00
|
|
|
import os
|
2025-02-21 17:07:37 +08:00
|
|
|
from typing import TYPE_CHECKING, Optional, Tuple
|
2025-02-05 10:53:12 +08:00
|
|
|
|
|
|
|
|
import torch
|
2025-04-15 10:18:05 +08:00
|
|
|
from vllm.logger import logger
|
2025-04-18 08:56:05 +08:00
|
|
|
from vllm.platforms import Platform, PlatformEnum
|
2025-04-03 14:52:34 +08:00
|
|
|
|
2025-10-25 15:36:32 +08:00
|
|
|
# TODO: remove this override once the CUDA-specific hard-coding in vLLM is
# resolved. It runs at import time, before the vllm_ascend imports below, and
# sets VLLM_DISABLE_SHARED_EXPERTS_STREAM=1 — presumably so vLLM does not take
# its CUDA-only shared-experts stream path on NPU (NOTE(review): confirm).
os.environ["VLLM_DISABLE_SHARED_EXPERTS_STREAM"] = "1"
|
2025-10-25 15:36:32 +08:00
|
|
|
|
2025-07-03 22:21:42 +08:00
|
|
|
from vllm_ascend.ascend_config import (check_ascend_config, get_ascend_config,
|
|
|
|
|
init_ascend_config)
|
2025-09-03 17:56:12 +08:00
|
|
|
from vllm_ascend.torchair.utils import (check_torchair_cache_exist,
|
|
|
|
|
delete_torchair_cache_file)
|
[Quantization] Support compressed tensors w8a8 static and w8a8 dynamic weight (#4036)
### What this PR does / why we need it?
While using the LLM Compressor quantization tool from the VLLM community
to generate quantized weights, the VLLM Ascend engine needs to be
adapted to support the compressed tensors quantization format.
1. Add AscendCompressedTensorsConfig to replace CompressedTensorsConfig
in vllm.
2. Support CompressedTensorsW8A8 static weight.
- weight: per-channel, int8, symmetric; activation: per-tensor, int8,
symmetric.
3. Support CompressedTensorsW8A8Dynamic weight.
- weight: per-channel, int8, symmetric; activation: per-token, int8,
symmetric, dynamic.
4. Modify the override_quantization_method in AscendQuantConfig.
Co-authored-by: taoqun110 taoqun@huawei.com
Co-authored-by: chenxi-hh chen464822955@163.com
- vLLM version: v0.11.2
---------
Signed-off-by: LHXuuu <scut_xlh@163.com>
Signed-off-by: chenxi-hh <chen464822955@163.com>
Signed-off-by: chenxi-hh <32731611+chenxi-hh@users.noreply.github.com>
Co-authored-by: chenxi-hh <chen464822955@163.com>
Co-authored-by: chenxi-hh <32731611+chenxi-hh@users.noreply.github.com>
2025-11-28 14:09:39 +08:00
|
|
|
|
|
|
|
|
# isort: off
|
|
|
|
|
from vllm_ascend.utils import (
|
|
|
|
|
ASCEND_QUANTIZATION_METHOD, COMPRESSED_TENSORS_METHOD, AscendDeviceType,
|
|
|
|
|
enable_sp, get_ascend_device_type, is_vl_model,
|
|
|
|
|
prefill_context_parallel_enable, update_aclgraph_sizes,
|
|
|
|
|
update_cudagraph_capture_sizes, update_default_aclgraph_sizes)
|
2025-05-12 20:26:22 +08:00
|
|
|
|
2025-02-21 17:07:37 +08:00
|
|
|
if TYPE_CHECKING:
|
2025-03-28 19:34:23 +08:00
|
|
|
from vllm.config import ModelConfig, VllmConfig
|
2025-02-21 17:07:37 +08:00
|
|
|
from vllm.utils import FlexibleArgumentParser
|
|
|
|
|
else:
|
2025-03-28 19:34:23 +08:00
|
|
|
ModelConfig = None
|
|
|
|
|
VllmConfig = None
|
2025-02-21 17:07:37 +08:00
|
|
|
FlexibleArgumentParser = None
|
|
|
|
|
|
2025-12-03 19:27:38 +08:00
|
|
|
# Per-process flag: NPUPlatform.import_kernels sets it to True after its first
# run, so later calls in the same process return early.
CUSTOM_OP_REGISTERED = False
|
|
|
|
|
|
2025-02-05 10:53:12 +08:00
|
|
|
|
|
|
|
|
class NPUPlatform(Platform):
|
|
|
|
|
|
|
|
|
|
# Registered with vLLM as an out-of-tree (OOT) platform.
_enum = PlatformEnum.OOT
device_name: str = "npu"
device_type: str = "npu"
simple_compile_backend: str = "eager"  # Disable torch.compile()
# Presumably the Ray resource name requested for NPUs — confirm against vLLM.
ray_device_key: str = "NPU"
# Environment variable that restricts which NPUs a process can see
# (the Ascend analogue of CUDA_VISIBLE_DEVICES).
device_control_env_var: str = "ASCEND_RT_VISIBLE_DEVICES"
# PyTorch dispatch key of the NPU backend (torch's PrivateUse1 slot).
dispatch_key: str = "PrivateUse1"
|
2025-02-05 10:53:12 +08:00
|
|
|
|
[Quantization] Support compressed tensors w8a8 static and w8a8 dynamic weight (#4036)
### What this PR does / why we need it?
While using the LLM Compressor quantization tool from the VLLM community
to generate quantized weights, the VLLM Ascend engine needs to be
adapted to support the compressed tensors quantization format.
1. Add AscendCompressedTensorsConfig to replace CompressedTensorsConfig
in vllm.
2. Support CompressedTensorsW8A8 static weight.
- weight: per-channel, int8, symmetric; activation: per-tensor, int8,
symmetric.
4. Support CompressedTensorsW8A8Dynamic weight.
- weight: per-channel, int8, symmetric; activation: per-token, int8,
symmetric, dynamic.
5. Modify the override_quantization_method in AscendQuantConfig.
Co-authored-by: taoqun110 taoqun@huawei.com
Co-authored-by: chenxi-hh chen464822955@163.com
- vLLM version: v0.11.2
---------
Signed-off-by: LHXuuu <scut_xlh@163.com>
Signed-off-by: chenxi-hh <chen464822955@163.com>
Signed-off-by: chenxi-hh <32731611+chenxi-hh@users.noreply.github.com>
Co-authored-by: chenxi-hh <chen464822955@163.com>
Co-authored-by: chenxi-hh <32731611+chenxi-hh@users.noreply.github.com>
2025-11-28 14:09:39 +08:00
|
|
|
# Quantization methods accepted on this platform: the Ascend-native format and
# the compressed-tensors format (see pre_register_and_update for the CLI side).
supported_quantization: list[str] = [
    ASCEND_QUANTIZATION_METHOD, COMPRESSED_TENSORS_METHOD
]
|
2025-02-21 17:07:37 +08:00
|
|
|
|
Add sleep mode feature for Ascend NPU (#513)
### What this PR does / why we need it?
This PR adds a sleep mode feature to vllm-ascend. When sleeping, it does
two main things:
- offload model weights
- discard kv cache
RLHF tools (such as https://github.com/volcengine/verl and
https://github.com/OpenRLHF/OpenRLHF) have a strong need for sleep mode
to accelerate the training process.
This PR may solve #375 and #320 .
### Does this PR introduce _any_ user-facing change?
No existing user interfaces changed.
Users will have two new methods(`sleep()` and `wake_up()`) to use.
### How was this patch tested?
This PR is tested with Qwen/Qwen2.5-0.5B-Instruct.
At first, we have free NPU memory M1.
After `llm = LLM("Qwen/Qwen2.5-0.5B-Instruct", enable_sleep_mode=True)`
executed, we have free NPU memory M2. M2 < M1.
Then we call `llm.sleep(level=1)`, we have free NPU memory M3.
We have M3 > M2, M3 is very close to M1.
Plus, we have the same output tokens before sleep and after wake up,
with the config of `SamplingParams(temperature=0, max_tokens=10)` and
with the same input tokens of course.
This PR is utilizing the CMake procedure of #371 , thanks a lot.
Signed-off-by: Shuqiao Li <celestialli@outlook.com>
2025-04-18 13:11:39 +08:00
|
|
|
def is_sleep_mode_available(self) -> bool:
    """Report whether sleep mode (``sleep()`` / ``wake_up()``) is supported.

    Always ``True`` on this platform.
    """
    sleep_supported = True
    return sleep_supported
|
|
|
|
|
|
2025-02-21 17:07:37 +08:00
|
|
|
@classmethod
def pre_register_and_update(cls,
                            parser: Optional[FlexibleArgumentParser] = None
                            ) -> None:
    """Prepare vLLM for the Ascend backend while CLI arguments are set up.

    Applies vllm-ascend's global patches, and makes ``ascend`` an accepted
    value of the ``--quantization`` option so users can run
    ``vllm serve --quantization ascend``. Finally it imports the Ascend
    quantization config modules.

    Args:
        parser: The CLI argument parser, or ``None``. When ``None``, only the
            patching and the imports take place.
    """
    # Apply vllm-ascend's process-wide (global) patches.
    from vllm_ascend.utils import adapt_patch
    adapt_patch(is_global_patch=True)

    # "ascend" is not a native `--quantization` choice, so append it. The
    # option's `choices` can be None (unrestricted); leave it alone then.
    quant_action = (parser._option_string_actions.get('--quantization')
                    if parser is not None else None)
    quant_choices = (getattr(quant_action, 'choices', None)
                     if quant_action else None)
    if quant_choices and ASCEND_QUANTIZATION_METHOD not in quant_choices:
        quant_choices.append(ASCEND_QUANTIZATION_METHOD)

    # These names are unused here (noqa: F401); the modules are imported for
    # their side effects (presumably quantization-method registration).
    from vllm_ascend.quantization.compressed_tensors.compressed_tensors import \
        AscendCompressedTensorsConfig  # noqa: F401
    from vllm_ascend.quantization.quant_config import \
        AscendQuantConfig  # noqa: F401
|
|
|
|
|
|
2025-02-05 10:53:12 +08:00
|
|
|
@classmethod
def get_device_capability(cls, device_id: int = 0):
    """Return the device's compute capability, which is always ``None`` here.

    ``device_id`` is accepted for interface compatibility and is ignored.
    """
    del device_id  # not used on this platform
    return None
|
|
|
|
|
|
|
|
|
|
@classmethod
def get_device_name(cls, device_id: int = 0) -> str:
    """Return the name that ``torch.npu`` reports for NPU ``device_id``."""
    npu = torch.npu
    return npu.get_device_name(device_id)
|
2025-02-05 10:53:12 +08:00
|
|
|
|
|
|
|
|
@classmethod
def inference_mode(cls):
    """Return a new ``torch.inference_mode()`` context manager.

    Usable as ``with NPUPlatform.inference_mode(): ...``.
    """
    inference_guard = torch.inference_mode()
    return inference_guard
|
|
|
|
|
|
|
|
|
|
@classmethod
def set_device(cls, device: torch.device):
    """Select ``device`` as the active NPU via ``torch.npu.set_device``."""
    npu = torch.npu
    npu.set_device(device)
|
|
|
|
|
|
|
|
|
|
@classmethod
def empty_cache(cls):
    """Ask ``torch.npu`` to release its cached, unused device memory."""
    npu = torch.npu
    npu.empty_cache()
|
|
|
|
|
|
|
|
|
|
@classmethod
def synchronize(cls):
    """Block until ``torch.npu.synchronize`` returns."""
    npu = torch.npu
    npu.synchronize()
|
|
|
|
|
|
|
|
|
|
@classmethod
def mem_get_info(cls) -> Tuple[int, int]:
    """Return the pair reported by ``torch.npu.mem_get_info()`` unchanged."""
    npu = torch.npu
    return npu.mem_get_info()
|
|
|
|
|
|
2025-06-06 21:54:02 +08:00
|
|
|
@classmethod
def clear_npu_memory(cls):
    """Free as much NPU memory as possible and restart peak tracking.

    Runs the Python garbage collector first, so that unreachable tensors are
    dropped. It then empties the ``torch.npu`` cache and resets the
    peak-memory statistics.
    """
    npu = torch.npu
    gc.collect()
    npu.empty_cache()
    npu.reset_peak_memory_stats()
|
|
|
|
|
|
2025-02-05 10:53:12 +08:00
|
|
|
@classmethod
def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
    """Validate ``vllm_config`` and adapt it, in place, for Ascend NPUs.

    Steps, in order:

    1. Build the Ascend config from ``additional_config`` and validate it
       with ``check_ascend_config``.
    2. Apply an explicit ``kv_cache_dtype`` override. Otherwise, for models
       whose HF config defines ``index_topk``, use the model dtype as the
       KV-cache dtype.
    3. Normalize compilation / graph-capture settings: eager mode,
       unsupported compilation modes, torchair, sequence-parallel capture
       sizes, and the cudagraph modes supported on NPU.
    4. Pick the NPU worker class, the KV-cache block size, the custom ops and
       the scheduler configuration.
    5. Reject unsupported combinations.

    Args:
        vllm_config: Engine configuration; mutated in place.

    Raises:
        ValueError: ACL graph is enabled together with
            ``ASCEND_LAUNCH_BLOCKING=1``, or FLASHCOMM is enabled for a VL
            model.
        AssertionError: One of the following holds:
            - sequence parallelism leaves no usable capture size;
            - PIECEWISE graph mode is used without ``VLLM_COMPILE``;
            - the context-parallel interleave size differs from the block
              size in a P/D-disaggregated (KV transfer) setup.
    """
    # Initialize the Ascend config from vLLM's additional_config.
    ascend_config = init_ascend_config(vllm_config)

    from vllm.config import CompilationMode  # noqa: E402

    compilation_config = vllm_config.compilation_config
    model_config = vllm_config.model_config
    parallel_config = vllm_config.parallel_config
    cache_config = vllm_config.cache_config
    ascend_scheduler_config = ascend_config.ascend_scheduler_config

    kv_cache_dtype = vllm_config.additional_config.get(
        "kv_cache_dtype", None)
    if kv_cache_dtype is not None:
        vllm_config.cache_config.cache_dtype = kv_cache_dtype
    elif model_config and hasattr(model_config.hf_config, "index_topk"):
        # e.g. torch.bfloat16 -> "bfloat16"
        vllm_config.cache_config.cache_dtype = str(
            model_config.dtype).replace("torch.", "")
    if model_config is None:
        logger.warning("Model config is missing. This may indicate "
                       "that we are running a test case")
        enforce_eager = False
    else:
        enforce_eager = getattr(model_config, "enforce_eager", False)

    check_ascend_config(vllm_config, enforce_eager)
    from vllm.config.compilation import CUDAGraphMode
    if enforce_eager:
        logger.info("Compilation disabled, using eager mode by default")
        compilation_config.mode = CompilationMode.NONE
        if compilation_config.splitting_ops is None:
            compilation_config.splitting_ops = []

    compilation_config.cudagraph_num_of_warmups = 1
    compilation_config.pass_config.fuse_norm_quant = False
    compilation_config.pass_config.fuse_act_quant = False

    if compilation_config.mode not in [
            CompilationMode.NONE, CompilationMode.VLLM_COMPILE
    ]:
        logger.warning(
            "NPU does not support %s compilation mode. Setting CUDAGraphMode to NONE",
            compilation_config.mode)
        compilation_config.cudagraph_mode = CUDAGraphMode.NONE

    # Set CUDAGraphMode to NONE when torchair is enabled, no matter what
    # compilation_config.mode is.
    if ascend_config.torchair_graph_config.enabled:
        logger.info(
            "Torchair compilation enabled on NPU. Setting CUDAGraphMode to NONE"
        )
        compilation_config.cudagraph_mode = CUDAGraphMode.NONE
        # Note: We delete the torchair cache folder here to prevent runtime issues caused by dimension
        # mismatches or configuration inconsistencies when users reuse cached computation graphs. Though
        # this will increase graph compilation duration, it significantly enhances robustness and decreases
        # graph launching time during inference.
        if check_torchair_cache_exist(
        ) and not ascend_config.torchair_graph_config.use_cached_kv_cache_bytes:
            logger.warning(
                "Torchair cache folder is deleted here to prevent runtime issues caused by dimension "
                "mismatches or configuration inconsistencies when users reuse cached computation graphs. "
                "In order to decrease torchair graph compilation time, users can enable both use_cached_graph "
                "and use_cached_kv_cache_bytes in torchair_graph_config.")
            delete_torchair_cache_file()

    # Set cudagraph sizes before extending `compilation_config.splitting_ops`.
    vllm_config._set_cudagraph_sizes()
    # There are cases where default cudagraph_capture_sizes are not friendly
    # to ascend ops && hardwares. We update these sizes here to improve
    # default performance.
    update_default_aclgraph_sizes(vllm_config)
    # TODO delete graph size update here when compilation_config.pass_config.enable_sp
    # is supported by vllm-ascend.
    # `model_config` may be None (see the warning above); check it before
    # reading `enforce_eager` instead of dereferencing it unconditionally.
    if vllm_config.parallel_config.tensor_parallel_size > 1 and \
            model_config is not None and \
            not model_config.enforce_eager and \
            enable_sp(vllm_config):
        original_sizes = compilation_config.cudagraph_capture_sizes
        sp_aclgraph_sizes = \
            vllm_config.update_sizes_for_sequence_parallelism(original_sizes)
        assert sp_aclgraph_sizes, (
            f"cudagraph_capture_sizes {original_sizes} does not contain "
            f"values that are multiples of tp_size "
            f"{vllm_config.parallel_config.tensor_parallel_size}")
        if len(sp_aclgraph_sizes) != len(original_sizes):
            compilation_config.cudagraph_capture_sizes = sp_aclgraph_sizes
            update_cudagraph_capture_sizes(vllm_config, sp_aclgraph_sizes)

    # TODO: Full graph is fully supported later, and the default value will be set to full graph.
    if compilation_config.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE:
        compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE

    if compilation_config.cudagraph_mode == CUDAGraphMode.NONE:
        compilation_config.mode = CompilationMode.NONE
    elif compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE:
        logger.info(
            "PIECEWISE compilation enabled on NPU. use_inductor not supported - "
            "using only ACL Graph mode")
        assert compilation_config.mode == CompilationMode.VLLM_COMPILE, \
            "When enabling PIECEWISE aclgraph, please make sure compilation_config.mode == CompilationMode.VLLM_COMPILE and compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE"
        compilation_config.set_splitting_ops_for_v1()
        compilation_config.use_inductor = False
        compilation_config.splitting_ops.extend(["vllm::mla_forward"])
        update_aclgraph_sizes(vllm_config)
    elif compilation_config.cudagraph_mode in (
            CUDAGraphMode.FULL_DECODE_ONLY, CUDAGraphMode.FULL):
        logger.info(
            "FULL_DECODE_ONLY compilation enabled on NPU. use_inductor not supported - "
            "using only ACL Graph mode")
        compilation_config.use_inductor = False
        warning_message = """\033[91m
        **********************************************************************************
        * WARNING: You have enabled the *full graph* feature.
        * This is an early experimental stage and may involve various unknown issues.
        * A known problem is that capturing too many batch sizes can lead to OOM
        * (Out of Memory) errors or inference hangs. If you encounter such issues,
        * consider reducing `gpu_memory_utilization` or manually specifying a smaller
        * batch size for graph capture.
        * For more details, please refer to:
        * https://docs.vllm.ai/en/stable/configuration/conserving_memory.html#reduce-cuda-graphs
        **********************************************************************************\033[0m
        """
        logger.warning(warning_message)
    else:
        logger.info(
            "%s cudagraph_mode is not support on NPU. falling back to NONE",
            compilation_config.cudagraph_mode)
        compilation_config.cudagraph_mode = CUDAGraphMode.NONE
        compilation_config.mode = CompilationMode.NONE

    # TODO: Remove this check when ACL Graph supports ASCEND_LAUNCH_BLOCKING=1
    # Then, we will have to discuss the error handling strategy and user experience
    if compilation_config.cudagraph_mode != CUDAGraphMode.NONE and \
            os.environ.get("ASCEND_LAUNCH_BLOCKING", "0") == "1":
        raise ValueError(
            "ACL graph is incompatible with ASCEND_LAUNCH_BLOCKING=1. "
            "Please unset ASCEND_LAUNCH_BLOCKING or set it to 0. If you "
            "need ASCEND_LAUNCH_BLOCKING for debugging, consider other methods — "
            "for example, check the plog files (default: $HOME/ascend/log/debug) "
            "for more information about runtime errors.")

    if parallel_config and parallel_config.worker_cls == "auto":
        # TODO: this is a tricky way to disable `use_sequence_parallel_moe` in vllm.
        parallel_config.all2all_backend = "flashinfer_all2allv"
        if ascend_config.torchair_graph_config.enabled:
            parallel_config.worker_cls = "vllm_ascend.torchair.torchair_worker.NPUTorchairWorker"
        else:
            parallel_config.worker_cls = "vllm_ascend.worker.worker_v1.NPUWorker"

    if cache_config:
        if cache_config.block_size is None:
            cache_config.block_size = 128

        if cache_config.enable_prefix_caching or \
                not ascend_scheduler_config.enabled or \
                getattr(ascend_scheduler_config, "enable_chunked_prefill", False):
            logger.warning(
                "If chunked prefill or prefix caching is enabled, block size must be set to 128."
            )
            origin_block_size = cache_config.block_size
            cache_config.block_size = 128
            # TODO(MengqingCao): Remove the model_type check, after resolving the hidden error in get_kv_cache_groups.
            if model_config and model_config.hf_config.model_type == "qwen3_next":
                logger.warning(
                    "When running qwen3-next model, block_size needs to be restored to its original value."
                )
                cache_config.block_size = origin_block_size

    # Activate custom ops for v1, except on 310P
    if get_ascend_device_type() != AscendDeviceType._310P:
        compilation_config.custom_ops = ["all"]

    # If ascend_scheduler_config is enabled,
    # extend the original scheduler_config to use AscendScheduler.
    if ascend_config.ascend_scheduler_config.enabled:
        from vllm_ascend.core.schedule_config import AscendSchedulerConfig
        ascend_scheduler_config = AscendSchedulerConfig.initialize_from_config(
            vllm_config.scheduler_config,
            ascend_config.ascend_scheduler_config)
        vllm_config.scheduler_config = ascend_scheduler_config
    elif ascend_config.recompute_scheduler_enable:
        from vllm_ascend.core.recompute_schedule_config import \
            RecomputeSchedulerConfig
        recompute_scheduler_config = RecomputeSchedulerConfig.initialize_from_config(
            vllm_config.scheduler_config)
        vllm_config.scheduler_config = recompute_scheduler_config

    # Extend original scheduler_config to use SchedulerDynamicBatch.
    if ascend_config.SLO_limits_for_dynamic_batch != -1:
        vllm_config.scheduler_config.scheduler_cls = (
            "vllm_ascend.core.scheduler_dynamic_batch.SchedulerDynamicBatch"
        )
        vllm_config.scheduler_config.enable_chunked_prefill = True
        vllm_config.scheduler_config.SLO_limits_for_dynamic_batch = ascend_config.SLO_limits_for_dynamic_batch

    if vllm_config.kv_transfer_config is not None and \
            prefill_context_parallel_enable() and \
            cache_config.block_size != parallel_config.cp_kv_cache_interleave_size and \
            parallel_config.decode_context_parallel_size * parallel_config.prefill_context_parallel_size > 1:
        raise AssertionError(
            f"cp_kv_cache_interleave_size({parallel_config.cp_kv_cache_interleave_size}) "
            f"and block_size({cache_config.block_size}) "
            "needs to be equal if use cp or dcp > 1 in P/D disaggregate scenario."
        )

    if is_vl_model(vllm_config):
        if bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM", '0'))) or \
                bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM1", '0'))):
            raise ValueError(
                "Currently, VL models doesn't support "
                "FLASHCOMM in vllm-ascend. We will fix this in the future. "
                "Please set VLLM_ASCEND_ENABLE_FLASHCOMM1=0.")
|
|
|
|
|
|
2025-11-14 11:32:51 +08:00
|
|
|
@classmethod
def import_kernels(cls) -> None:
    """Expose the bundled custom operator package to the CANN runtime.

    If ``_cann_ops_custom/vendors/vllm-ascend`` exists next to this module,
    it is placed at the front of the ``ASCEND_CUSTOM_OPP_PATH`` environment
    variable (colon-separated). The module-level ``CUSTOM_OP_REGISTERED``
    flag turns every call after the first one in a process into a no-op.
    """
    # Directly importing vllm_ascend_C prevents ASCEND_RT_VISIBLE_DEVICES
    # from being applied during runtime initialization, which causes bugs
    # in the RL module. Therefore, we currently use lazy initialization
    # to avoid this issue. See https://github.com/vllm-project/vllm-ascend/pull/884.
    # TODO: when the above issue is fixed, we can uncomment the following lines.
    # from vllm_ascend.utils import enable_custom_op
    # enable_custom_op()

    # Set the custom ops path.
    global CUSTOM_OP_REGISTERED
    if CUSTOM_OP_REGISTERED:
        return
    cur_dir = os.path.dirname(os.path.realpath(__file__))
    custom_opp_path = os.path.join(cur_dir, "_cann_ops_custom", "vendors",
                                   "vllm-ascend")
    if os.path.exists(custom_opp_path):
        current_cust_opp_path = os.environ.get("ASCEND_CUSTOM_OPP_PATH", "")
        if not current_cust_opp_path:
            os.environ["ASCEND_CUSTOM_OPP_PATH"] = custom_opp_path
        elif current_cust_opp_path.split(":")[0] != custom_opp_path:
            # CUSTOM_OP_REGISTERED is per process, but os.environ is
            # inherited by child processes. A child that re-imports this
            # module (e.g. a spawned worker) starts with the flag unset while
            # the variable already begins with our path; only prepend when
            # it does not, so the entry is not duplicated.
            os.environ["ASCEND_CUSTOM_OPP_PATH"] = (
                f"{custom_opp_path}:{current_cust_opp_path}")
    CUSTOM_OP_REGISTERED = True
|
2025-11-14 11:32:51 +08:00
|
|
|
|
2025-02-05 10:53:12 +08:00
|
|
|
@classmethod
def get_attn_backend_cls(
    cls,
    selected_backend,
    head_size,
    dtype,
    kv_cache_dtype,
    block_size,
    use_mla,
    has_sink=False,
    use_sparse=False,
    attn_type: str | None = None,
):
    """Return the import path of the attention backend class to use.

    The choice depends only on ``use_mla``, ``use_sparse`` and two Ascend
    config values: ``enable_shared_expert_dp`` and
    ``torchair_graph_config.enabled``. The other parameters belong to the
    platform interface and are accepted but not consulted here.

    Raises:
        ValueError: if no backend exists for the flag combination, i.e.
            sparse attention requested without MLA.
    """
    ascend_config = get_ascend_config()
    use_mla = bool(use_mla)
    use_sparse = bool(use_sparse)

    # Sparse MLA with shared-expert DP always uses the torchair SFA backend,
    # whether or not torchair graph mode is enabled.
    if use_mla and use_sparse and ascend_config.enable_shared_expert_dp:
        return "vllm_ascend.torchair.torchair_sfa.AscendSFATorchairBackend"

    use_torchair = bool(ascend_config.torchair_graph_config.enabled)
    # Keyed by (use_mla, use_sparse, use_torchair).
    backend_map = {
        (True, False, True):
        "vllm_ascend.torchair.torchair_mla.AscendMLATorchairBackend",
        (True, False, False):
        "vllm_ascend.attention.mla_v1.AscendMLABackend",
        (False, False, True):
        "vllm_ascend.torchair.torchair_attention.AscendAttentionTorchairBackend",
        (False, False, False):
        "vllm_ascend.attention.attention_v1.AscendAttentionBackend",
        (True, True, False):
        "vllm_ascend.attention.sfa_v1.AscendSFABackend",
        (True, True, True):
        "vllm_ascend.torchair.torchair_sfa.AscendSFATorchairBackend",
    }
    backend = backend_map.get((use_mla, use_sparse, use_torchair))
    if backend is None:
        # Only (use_mla=False, use_sparse=True) combinations are missing.
        raise ValueError(
            f"Unsupported attention configuration: use_mla={use_mla}, "
            f"use_sparse={use_sparse}, use_torchair={use_torchair}. "
            "Sparse attention requires MLA on Ascend.")
    return backend
|
2025-02-05 10:53:12 +08:00
|
|
|
|
2025-04-17 16:48:46 +08:00
|
|
|
@classmethod
def get_punica_wrapper(cls) -> str:
    """Return the import path of the NPU Punica wrapper used for LoRA."""
    punica_wrapper_path = "vllm_ascend.lora.punica_npu.PunicaWrapperNPU"
    return punica_wrapper_path
|
2025-04-17 16:48:46 +08:00
|
|
|
|
2025-02-05 10:53:12 +08:00
|
|
|
@classmethod
def get_current_memory_usage(
        cls, device: Optional[torch.types.Device] = None) -> float:
    """Return the memory currently allocated on ``device``.

    The peak-memory counter is reset first and then read back. Assuming
    ``torch.npu`` mirrors ``torch.cuda`` allocator-stat semantics, the value
    read is the current allocation. As a side effect, any previously
    recorded peak for ``device`` is discarded.
    """
    npu = torch.npu
    npu.reset_peak_memory_stats(device)
    allocated_now = npu.max_memory_allocated(device)
    return allocated_now
|
|
|
|
|
|
|
|
|
|
@classmethod
def get_device_communicator_cls(cls) -> str:
    """Return the import path of the NPU device communicator class."""
    communicator_path = "vllm_ascend.distributed.communicator.NPUCommunicator"
    return communicator_path
|
2025-03-20 19:34:44 +08:00
|
|
|
|
|
|
|
|
@classmethod
def is_pin_memory_available(cls):
    """Report that pinned host memory may be used on this platform."""
    pin_memory_supported = True
    return pin_memory_supported
|
2025-03-28 19:34:23 +08:00
|
|
|
|
2025-10-25 08:58:35 +08:00
|
|
|
@classmethod
def opaque_attention_op(cls) -> bool:
    """Return True unconditionally for this platform.

    NOTE(review): presumably consumed by vLLM to decide whether attention is
    registered as a single opaque custom op; confirm against the vLLM
    Platform interface.
    """
    is_opaque = True
    return is_opaque
|
|
|
|
|
|
2025-05-29 11:58:26 +08:00
|
|
|
@classmethod
def get_static_graph_wrapper_cls(cls) -> str:
    """Return the import path of the static-graph wrapper class.

    On Ascend, captured static graphs are wrapped by ``ACLGraphWrapper``.
    """
    wrapper_path = "vllm_ascend.compilation.acl_graph.ACLGraphWrapper"
    return wrapper_path
|
2025-06-09 14:08:18 +08:00
|
|
|
|
2025-09-16 01:17:42 +08:00
|
|
|
@classmethod
def support_hybrid_kv_cache(cls) -> bool:
    """Report that this platform supports a hybrid KV cache."""
    hybrid_supported = True
    return hybrid_supported
|
2025-09-22 17:14:28 +08:00
|
|
|
|
|
|
|
|
@classmethod
def support_static_graph_mode(cls) -> bool:
    """Report that this platform supports static graph mode."""
    static_graph_supported = True
    return static_graph_supported
|