2025-02-05 10:53:12 +08:00
|
|
|
#
|
|
|
|
|
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
|
|
|
|
#
|
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
|
#
|
|
|
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
|
#
|
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
|
# limitations under the License.
|
2025-04-17 14:59:56 +08:00
|
|
|
# This file is a part of the vllm-ascend project.
|
2025-02-05 10:53:12 +08:00
|
|
|
#
|
|
|
|
|
|
2025-04-03 14:52:34 +08:00
|
|
|
import logging
|
2025-02-05 10:53:12 +08:00
|
|
|
import os
|
2025-02-21 17:07:37 +08:00
|
|
|
from typing import TYPE_CHECKING, Optional, Tuple
|
2025-02-05 10:53:12 +08:00
|
|
|
|
|
|
|
|
import torch
|
2025-03-20 19:34:44 +08:00
|
|
|
import vllm.envs as envs
|
2025-04-15 10:18:05 +08:00
|
|
|
from vllm.logger import logger
|
2025-04-18 08:56:05 +08:00
|
|
|
from vllm.platforms import Platform, PlatformEnum
|
2025-04-03 14:52:34 +08:00
|
|
|
|
2025-05-17 17:36:04 +08:00
|
|
|
from vllm_ascend.utils import ASCEND_QUATIZATION_METHOD, update_aclgraph_sizes
|
2025-05-12 20:26:22 +08:00
|
|
|
|
2025-04-18 08:56:05 +08:00
|
|
|
CUSTOM_OP_ENABLED = False
try:
    # Importing the compiled extension registers the custom ops into
    # torch_library; the module object itself is not referenced afterwards.
    import vllm_ascend.vllm_ascend_C  # type: ignore # noqa: F401
except ImportError as import_error:
    logging.warning(
        "Failed to import 'vllm_ascend.vllm_ascend_C': %s. All custom ops will be disabled. ",
        import_error)
else:
    # Only reached when the extension imported cleanly.
    CUSTOM_OP_ENABLED = True
|
2025-02-05 10:53:12 +08:00
|
|
|
|
2025-02-21 17:07:37 +08:00
|
|
|
if TYPE_CHECKING:
|
2025-03-28 19:34:23 +08:00
|
|
|
from vllm.config import ModelConfig, VllmConfig
|
2025-02-21 17:07:37 +08:00
|
|
|
from vllm.utils import FlexibleArgumentParser
|
|
|
|
|
else:
|
2025-03-28 19:34:23 +08:00
|
|
|
ModelConfig = None
|
|
|
|
|
VllmConfig = None
|
2025-02-21 17:07:37 +08:00
|
|
|
FlexibleArgumentParser = None
|
|
|
|
|
|
2025-02-05 10:53:12 +08:00
|
|
|
# Set at import time for every process that loads this module.
# NOTE(review): judging by its name, this flag presumably tells Ray not to
# overwrite ASCEND_RT_VISIBLE_DEVICES for its workers — confirm against Ray.
os.environ.update(RAY_EXPERIMENTAL_NOSET_ASCEND_RT_VISIBLE_DEVICES="1")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class NPUPlatform(Platform):
    """vLLM out-of-tree (``PlatformEnum.OOT``) platform for Ascend NPUs.

    Tells vLLM which device, worker, attention backend, communicator, LoRA
    wrapper and piecewise-graph backend classes to use on NPU, and adjusts
    the engine configuration to what the NPU stack supports.
    """

    _enum = PlatformEnum.OOT
    device_name: str = "npu"
    device_type: str = "npu"
    simple_compile_backend: str = "eager"  # Disable torch.compile()
    ray_device_key: str = "NPU"
    device_control_env_var: str = "ASCEND_RT_VISIBLE_DEVICES"
    dispatch_key: str = "PrivateUse1"

    supported_quantization: list[str] = [ASCEND_QUATIZATION_METHOD]

    def is_sleep_mode_available(self) -> bool:
        """Report that sleep mode (weight offload + KV-cache discard) is supported."""
        return True

    @classmethod
    def pre_register_and_update(cls,
                                parser: Optional[FlexibleArgumentParser] = None
                                ) -> None:
        """Apply global patches and expose the "ascend" quantization choice.

        Args:
            parser: CLI argument parser to extend. May be ``None``, in which
                case only the global patch and the quantization-config import
                are performed.
        """
        # Adapt the global patch here.
        from vllm_ascend.utils import adapt_patch
        adapt_patch(is_global_patch=True)

        # For online serving, "ascend" quantization method is not a choice natively,
        # so we need to add "ascend" quantization method to quantization methods list
        # and the user can enable quantization using "vllm serve --quantization ascend".
        if parser is not None:
            quant_action = parser._option_string_actions.get('--quantization')
            # Every argparse Action has a ``choices`` attribute, but it may be
            # None (unrestricted) or an immutable sequence such as a tuple.
            choices = getattr(quant_action, 'choices', None)
            if choices is not None and ASCEND_QUATIZATION_METHOD not in choices:
                if isinstance(choices, list):
                    choices.append(ASCEND_QUATIZATION_METHOD)
                else:
                    quant_action.choices = [
                        *choices, ASCEND_QUATIZATION_METHOD
                    ]

        # Imported only for its import-time side effects (presumably
        # registering AscendQuantConfig with vLLM); the name is unused here.
        from vllm_ascend.quantization.quant_config import \
            AscendQuantConfig  # noqa: F401

    @classmethod
    def get_device_capability(cls, device_id: int = 0):
        """NPUs expose no CUDA-style compute capability; always ``None``."""
        return None

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the name of NPU ``device_id`` as reported by torch.npu."""
        return torch.npu.get_device_name(device_id)

    @classmethod
    def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
        """Async output processing is supported regardless of eager mode."""
        return True

    @classmethod
    def inference_mode(cls):
        """Return the context manager used to run inference (torch.inference_mode)."""
        return torch.inference_mode()

    @classmethod
    def set_device(cls, device: torch.device):
        """Make ``device`` the current NPU for this process."""
        torch.npu.set_device(device)

    @classmethod
    def empty_cache(cls):
        """Release cached, unused NPU memory held by the allocator."""
        torch.npu.empty_cache()

    @classmethod
    def synchronize(cls):
        """Block until all queued work on the current NPU has finished."""
        torch.npu.synchronize()

    @classmethod
    def mem_get_info(cls) -> Tuple[int, int]:
        """Return the NPU memory info tuple from ``torch.npu.mem_get_info``."""
        return torch.npu.mem_get_info()

    @classmethod
    def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
        """Validate ``vllm_config`` for NPU execution and adjust it in place.

        Sets the expert-parallel sizes, validates the graph-mode options in
        ``additional_config``, chooses the compilation level, resolves the
        worker class, fixes the KV-cache block size and, on the V1 engine,
        enables all custom ops and the optional Ascend scheduler.

        Raises:
            RuntimeError: ``enable_graph_mode`` is combined with
                ``enforce_eager``.
            NotImplementedError: NPU graph mode is requested for a
                non-DeepSeek model, or ACL graph would be used for a DeepSeek
                model on the V1 engine.
        """
        from vllm.config import CompilationLevel  # noqa: E402
        compilation_config = vllm_config.compilation_config
        model_config = vllm_config.model_config
        additional_config = vllm_config.additional_config
        parallel_config = vllm_config.parallel_config
        cache_config = vllm_config.cache_config

        if parallel_config:
            if getattr(parallel_config, "enable_expert_parallel", False):
                # NOTE: When enable_expert_parallel is True, we follow vLLM
                # convention: ep_size = world_size, which means
                # expert_tensor_parallel_size must be 1. Any user-supplied
                # expert_tensor_parallel_size is ignored in this mode.
                parallel_config.expert_tensor_parallel_size = 1
            elif (additional_config
                  and "expert_tensor_parallel_size" in additional_config):
                parallel_config.expert_tensor_parallel_size = int(
                    additional_config["expert_tensor_parallel_size"])
            else:
                # Default value for expert tensor parallel size
                parallel_config.expert_tensor_parallel_size = (
                    parallel_config.tensor_parallel_size)

            # Calculate expert parallel size based on world size
            parallel_config.expert_parallel_size = (
                parallel_config.world_size //
                parallel_config.expert_tensor_parallel_size)

        if model_config is None:
            logger.warning("Model config is missing. This may indicate "
                           "that we are running a test case")
            enforce_eager = False
        else:
            enforce_eager = getattr(model_config, "enforce_eager", False)

        # additional_config can be None or an empty mapping; both mean
        # "graph mode not requested", and the ACL graph model check below
        # must still run in that case.
        enable_graph_mode = bool(
            additional_config
            and additional_config.get("enable_graph_mode", False))
        if enable_graph_mode:
            if enforce_eager:
                raise RuntimeError(
                    "Can't enable graph mode and eager mode at the same time. Please set `enforce_eager=False` if you attempt to enable NPU graph mode."
                )
            elif envs.VLLM_USE_V1 and envs.VLLM_MLA_DISABLE:
                logger.warning(
                    "NPU graph mode is still experimental and not supported for V1 without mla currently, "
                    "it has been disabled automatically.")
                additional_config["enable_graph_mode"] = False
            if model_config:
                model_type = model_config.hf_config.model_type
                if "deepseek" not in model_type:
                    raise NotImplementedError(
                        "enable_graph_mode only works with deepseek model."
                    )
        elif envs.VLLM_USE_V1 and model_config is not None and not enforce_eager:
            model_type = model_config.hf_config.model_type
            if "deepseek" in model_type:
                raise NotImplementedError(
                    "ACL Graph does not support deepseek. Please "
                    "adopt additional_config={'enable_graph_mode': True} "
                    "to serve deepseek models with NPU graph mode on vllm-ascend with V1 engine."
                    " Or set `enforce_eager=True` to use eager mode.")
            elif "qwen" not in model_type:
                logger.warning(
                    "ACL Graph is currently experimental. Please "
                    "raise an issue on https://github.com/vllm-project/vllm-ascend/issues"
                    " if you encounter any Error")

        if enforce_eager or compilation_config.level == CompilationLevel.NO_COMPILATION:
            logger.info("Compilation disabled, using eager mode by default")
            compilation_config.level = CompilationLevel.NO_COMPILATION
        elif compilation_config.level != CompilationLevel.PIECEWISE:
            logger.warning(
                "NPU does not support %s compilation level. Setting level to NO_COMPILATION",
                compilation_config.level)
            compilation_config.level = CompilationLevel.NO_COMPILATION
        else:
            logger.info(
                "PIECEWISE compilation enabled on NPU. use_inductor not supported - "
                "using only ACL Graph mode")
            compilation_config.use_inductor = False
            # The piecewise graph is split at the Ascend attention op. Only
            # add it once so repeated calls do not accumulate duplicates.
            attn_op = "vllm.unified_ascend_attention_with_output"
            if attn_op not in compilation_config.splitting_ops:
                compilation_config.splitting_ops.append(attn_op)
            update_aclgraph_sizes(vllm_config)

        if parallel_config and parallel_config.worker_cls == "auto":
            if envs.VLLM_USE_V1:
                parallel_config.worker_cls = "vllm_ascend.worker.worker_v1.NPUWorker"
            elif vllm_config.speculative_config:
                parallel_config.worker_cls = "vllm.spec_decode.spec_decode_worker.create_spec_worker"
                parallel_config.sd_worker_cls = "vllm_ascend.worker.worker.NPUWorker"
            elif vllm_config.scheduler_config.is_multi_step:
                parallel_config.worker_cls = "vllm_ascend.worker.multi_step_worker.MultiStepWorker"
            else:
                parallel_config.worker_cls = "vllm_ascend.worker.worker.NPUWorker"

        if cache_config:
            if cache_config.block_size is None:
                cache_config.block_size = 128
            if cache_config.enable_prefix_caching and cache_config.block_size != 128:
                logger.warning(
                    "If prefix caching is enabled, block size must be set to 128."
                )
                cache_config.block_size = 128

        if envs.VLLM_USE_V1:
            # Activate custom ops for v1.
            compilation_config.custom_ops = ["all"]
            # If ascend_scheduler_config exists in additional_config,
            # extend the original scheduler_config to use AscendScheduler.
            if additional_config and additional_config.get(
                    "ascend_scheduler_config", None) is not None:
                additional_scheduler_config = additional_config.get(
                    "ascend_scheduler_config")
                from vllm_ascend.core.schedule_config import \
                    AscendSchedulerConfig
                ascend_scheduler_config = AscendSchedulerConfig.initialize_from_config(
                    vllm_config.scheduler_config, additional_scheduler_config)
                vllm_config.scheduler_config = ascend_scheduler_config

    @classmethod
    def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
                             kv_cache_dtype, block_size, use_v1, use_mla):
        """Return the dotted path of the attention backend class.

        The choice depends only on ``use_v1`` and ``use_mla``; the other
        arguments are part of the platform interface and are not consulted.
        """
        if use_v1 and use_mla:
            return "vllm_ascend.attention.mla_v1.AscendMLABackend"
        if use_v1:
            return "vllm_ascend.attention.attention_v1.AscendAttentionBackend"
        if use_mla:
            return "vllm_ascend.attention.attention.AscendMLAAttentionBackend"
        return "vllm_ascend.attention.attention.AscendAttentionBackend"

    @classmethod
    def get_punica_wrapper(cls) -> str:
        """Return the dotted path of the NPU LoRA (Punica) wrapper class."""
        return "vllm_ascend.lora.punica_wrapper.punica_npu.PunicaWrapperNPU"

    @classmethod
    def get_current_memory_usage(cls,
                                 device: Optional[torch.types.Device] = None
                                 ) -> float:
        """Return the NPU memory currently allocated on ``device``.

        The peak statistics are reset first, so the subsequent
        ``max_memory_allocated`` reading reflects the current allocation
        rather than a historical peak.
        """
        torch.npu.reset_peak_memory_stats(device)
        return torch.npu.max_memory_allocated(device)

    @classmethod
    def get_device_communicator_cls(cls) -> str:
        """Return the dotted path of the NPU device communicator class."""
        return "vllm_ascend.distributed.communicator.NPUCommunicator"

    @classmethod
    def is_pin_memory_available(cls):
        """Pinned host memory is available on this platform."""
        return True

    @classmethod
    def supports_v1(cls, model_config: ModelConfig) -> bool:
        """Returns whether the current platform can support v1 for the supplied
        model configuration.
        """
        return True

    @classmethod
    def get_piecewise_backend_cls(cls) -> str:
        """
        Get piecewise backend class for piecewise graph.
        """
        return "vllm_ascend.compilation.piecewise_backend.NPUPiecewiseBackend"  # noqa
|