#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
# Adapted from vllm-project/vllm/vllm/worker/gpu_worker.py
#

import copy
import gc
from types import NoneType

import torch
import torch.nn as nn
import torch_npu
import vllm.envs as envs_vllm
from torch_npu.op_plugin.atb._atb_ops import _register_atb_extensions
from torch_npu.profiler import dynamic_profile as dp
from vllm.config import CUDAGraphMode, VllmConfig, set_current_vllm_config
from vllm.distributed import ensure_model_parallel_initialized, init_distributed_environment
from vllm.distributed.ec_transfer import ensure_ec_transfer_initialized
from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized, get_kv_transfer_group, has_kv_transfer_group
from vllm.distributed.parallel_state import Handle, get_pp_group, get_tp_group
from vllm.logger import logger
from vllm.lora.request import LoRARequest
from vllm.sequence import IntermediateTensors
from vllm.tasks import SupportedTask
from vllm.utils.mem_constants import GiB_bytes
from vllm.utils.mem_utils import MemorySnapshot, memory_profiling
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput, DraftTokenIds, ModelRunnerOutput
from vllm.v1.worker.gpu_worker import AsyncIntermediateTensors
from vllm.v1.worker.worker_base import WorkerBase
from vllm.v1.worker.workspace import init_workspace_manager

import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import get_ascend_config, init_ascend_config
from vllm_ascend.batch_invariant import init_batch_invariance
from vllm_ascend.cpu_binding import bind_cpus
from vllm_ascend.device_allocator.camem import CaMemAllocator
from vllm_ascend.distributed.parallel_state import init_ascend_model_parallel
from vllm_ascend.ops.triton.triton_utils import init_device_properties_triton
from vllm_ascend.utils import (
    AscendDeviceType,
    check_ascend_device_type,
    enable_sp,
    get_ascend_device_type,
    register_ascend_customop,
)
from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
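
# Dynamo trace-rule adjustments: clear the cached rule lookups and register the
# torch.npu stream helpers as in-graph torch functions so that tracing them
# with torch.compile does not force a graph break.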
torch._dynamo.trace_rules.clear_lru_cache()  # noqa: E402

from torch._dynamo.variables import TorchInGraphFunctionVariable  # noqa: E402
from vllm.utils.torch_utils import set_random_seed  # noqa: E402

torch_non_c_binding_in_graph_functions_npu = dict.fromkeys(
    ["torch.npu.current_stream"],
    TorchInGraphFunctionVariable,
)  # noqa: E402
torch_non_c_binding_in_graph_functions_npu["torch.npu.stream"] = TorchInGraphFunctionVariable  # noqa: E402
torch._dynamo.trace_rules.torch_name_rule_map.append(torch_non_c_binding_in_graph_functions_npu)  # noqa: E402


class NPUWorker(WorkerBase):
    def __init__(
        self,
        vllm_config: VllmConfig,
        local_rank: int,
        rank: int,
        distributed_init_method: str,
        is_driver_worker: bool = False,
        # Additional parameters for compatibility with vllm
        **kwargs,
    ):
        """Initialize the worker for Ascend."""
        if not envs_ascend.COMPILE_CUSTOM_KERNELS:
            logger.warning(
                "COMPILE_CUSTOM_KERNELS is set to False. "
                "In most scenarios, without custom kernels, vllm-ascend will not function correctly."
            )

        # Register patches for vllm.
        from vllm_ascend.utils import adapt_patch

        adapt_patch()

        # Register ops when the worker is initialized.
        from vllm_ascend import ops

        ops.register_dummy_fusion_op()
        if get_ascend_device_type() != AscendDeviceType.A5:
            _register_atb_extensions()
        register_ascend_customop(vllm_config)
        # Init ascend config and soc version.
        init_ascend_config(vllm_config)
        check_ascend_device_type()

        super().__init__(
            vllm_config=vllm_config,
            local_rank=local_rank,
            rank=rank,
            distributed_init_method=distributed_init_method,
            is_driver_worker=is_driver_worker,
        )

        if self.cache_config.cache_dtype == "auto":
            self.cache_dtype = self.model_config.dtype
        else:
            self.cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[self.cache_config.cache_dtype]

        # Profiler is lazily initialized on the first profile(is_start=True) call (RFC #6954).
        self.profiler_config = vllm_config.profiler_config
        self.profiler = None
        if vllm_config.model_config and vllm_config.model_config.enable_sleep_mode:
            # Buffers saved before sleep.
            self._sleep_saved_buffers: dict[str, torch.Tensor] = {}

        # FixMe: this is a patch to fix the issue caused by https://github.com/vllm-project/vllm/commit/de94289a98d7ec52a5ef02719e01a1db8b505170
        from vllm.model_executor.layers.linear import WEIGHT_LOADER_V2_SUPPORTED

        if "UnquantizedLinearMethod" in WEIGHT_LOADER_V2_SUPPORTED:
            WEIGHT_LOADER_V2_SUPPORTED.remove("UnquantizedLinearMethod")

        self.use_v2_model_runner = envs_vllm.VLLM_USE_V2_MODEL_RUNNER
        self._pp_send_work: list[Handle] = []

        ascend_compilation_config = get_ascend_config().ascend_compilation_config
        if ascend_compilation_config.enable_npugraph_ex and ascend_compilation_config.enable_static_kernel:
            # Prevent duplicate triggers; execute the exit logic only once.
            shutdown_request = False

            def signal_handler(signum, frame):
                nonlocal shutdown_request
                if not shutdown_request:
                    shutdown_request = True
                    self.uninstall_static_kernel()
                raise SystemExit()

            # Either SIGTERM or SIGINT will terminate the worker.
            import signal

            signal.signal(signal.SIGTERM, signal_handler)
            signal.signal(signal.SIGINT, signal_handler)

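    # Clean up the installed static-kernel package on shutdown. A non-blocking
    # exclusive flock on a lock file ensures only one worker process launches
    # the uninstall script, which runs detached from this process.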
    def uninstall_static_kernel(self):
        import fcntl
        import os
        import subprocess

        ascend_home_path = os.environ["ASCEND_HOME_PATH"]
        static_kernel_dir_path = os.path.join(ascend_home_path, "opp/static_kernel")
        uninstall_script_path = os.path.join(static_kernel_dir_path, "ai_core/uninstall.sh")
        lock_file_path = os.path.join(static_kernel_dir_path, "uninstall.lock")

        if not os.path.exists(uninstall_script_path):
            return
        with open(lock_file_path, "w") as lock_fd:
            try:
                fcntl.flock(lock_fd, fcntl.LOCK_EX | fcntl.LOCK_NB)
                subprocess.Popen(
                    ["bash", uninstall_script_path],
                    stdin=subprocess.DEVNULL,
                    stdout=subprocess.DEVNULL,
                    stderr=subprocess.DEVNULL,
                    start_new_session=True,
                )
            except (BlockingIOError, OSError):
                return
            finally:
                try:
                    fcntl.flock(lock_fd, fcntl.LOCK_UN)
                    if os.path.exists(lock_file_path):
                        os.remove(lock_file_path)
                except Exception:
                    return

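    # Sleep mode (used by RL workflows): level 1 offloads model weights to host
    # memory and drops the KV cache; level 2 drops the weights as well, after
    # saving the model's named buffers for restoration in wake_up().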
    def sleep(self, level: int = 1) -> None:
        free_bytes_before_sleep = torch.npu.mem_get_info()[0]
        # Save the buffers before level 2 sleep.
        if level == 2:
            model = self.model_runner.model
            self._sleep_saved_buffers = {name: buffer.cpu().clone() for name, buffer in model.named_buffers()}
        allocator = CaMemAllocator.get_instance()
        allocator.sleep(offload_tags=("weights",) if level == 1 else tuple())
        free_bytes_after_sleep, total = torch.npu.mem_get_info()
        freed_bytes = free_bytes_after_sleep - free_bytes_before_sleep
        used_bytes = total - free_bytes_after_sleep
        assert freed_bytes >= 0, "Memory usage increased after sleeping."
        logger.info(
            "Sleep mode freed %.2f GiB memory, %.2f GiB memory is still in use.",
            freed_bytes / GiB_bytes,
            used_bytes / GiB_bytes,
        )

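    # Wake up from sleep mode: restore the allocator-managed memory, transpose
    # MoE w2/w13 expert weights that are still in their original layout
    # (guarded by the hidden_size shape check), and finally restore any
    # buffers saved before a level-2 sleep.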
    def wake_up(self, tags: list[str] | None = None) -> None:
        if envs_ascend.VLLM_ASCEND_ENABLE_NZ:
            raise ValueError(
                "FRACTAL_NZ mode is enabled. This may cause model parameter precision issues "
                "in RL scenarios. Please set VLLM_ASCEND_ENABLE_NZ=0."
            )
        allocator = CaMemAllocator.get_instance()
        allocator.wake_up(tags=tags)

        hidden_size = self.vllm_config.model_config.hf_text_config.hidden_size
        model = self.model_runner.model
        if tags is None or "weights" in tags:
            for name, param in model.named_parameters():
                if "w2_weight" in name and param.shape[2] == hidden_size:
                    parts = name.split(".")
                    param_name = parts[-1]
                    parent_module = model.get_submodule(".".join(parts[:-1]))

                    w2_data = param.transpose(1, 2)
                    w2_data = torch.nn.Parameter(w2_data, requires_grad=False)
                    setattr(parent_module, param_name, w2_data)
                elif "w13_weight" in name and param.shape[1] == hidden_size:
                    parts = name.split(".")
                    param_name = parts[-1]
                    parent_module = model.get_submodule(".".join(parts[:-1]))

                    w13_data = param.transpose(1, 2)
                    w13_data = torch.nn.Parameter(w13_data, requires_grad=False)
                    setattr(parent_module, param_name, w13_data)

        # Restore the buffers after level 2 sleep.
        if len(self._sleep_saved_buffers):
            for name, buffer in model.named_buffers():
                if name in self._sleep_saved_buffers:
                    buffer.data.copy_(self._sleep_saved_buffers[name].data)
            self._sleep_saved_buffers = {}

    def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None:
        self.cache_config.num_gpu_blocks = num_gpu_blocks
        self.cache_config.num_cpu_blocks = num_cpu_blocks

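    # Device bring-up helper: select the NPU device, take a baseline memory
    # snapshot, check that the requested gpu_memory_utilization fits in the
    # currently free device memory, and initialize the distributed environment.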
    def _init_device(self):
        device = torch.device(f"npu:{self.local_rank}")
        torch.npu.set_device(device)

        # Import _inductor for graph mode execution with triton.
        # This lazy import avoids torch_npu re-initialization in patch.
        # Note that this should be imported after torch.npu.set_device
        # to avoid repeated set_device in extra processes.
        from vllm.triton_utils import HAS_TRITON

        if HAS_TRITON:
            import torch_npu._inductor  # noqa: F401

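        # Flush Python garbage and cached NPU allocations so the memory
        # snapshot below reflects the real free-memory baseline.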
        gc.collect()
        torch.npu.empty_cache()

        # Take the current memory snapshot.
        self.init_snapshot = MemorySnapshot()
        self.requested_memory = self.init_snapshot.total_memory * self.cache_config.gpu_memory_utilization
        if (
            self.init_snapshot.free_memory < self.requested_memory
            and not envs_ascend.VLLM_ASCEND_ENABLE_VNPU
        ):
            GiB = lambda b: round(b / GiB_bytes, 2)
            raise ValueError(
                f"Free memory on device "
                f"({GiB(self.init_snapshot.free_memory)}/"
                f"{GiB(self.init_snapshot.total_memory)} GiB) on startup "
                f"is less than desired GPU memory utilization "
                f"({self.cache_config.gpu_memory_utilization}, "
                f"{GiB(self.requested_memory)} GiB). Decrease GPU memory "
                f"utilization or reduce GPU memory used by other processes."
            )

        if (
            self.parallel_config.data_parallel_size > 1
            and self.parallel_config.data_parallel_size_local > 0
            and self.parallel_config.distributed_executor_backend not in ["ray", "external_launcher"]
            and self.vllm_config.parallel_config.data_parallel_backend != "ray"
            and self.vllm_config.parallel_config.nnodes_within_dp == 1
        ):
            visible_device_count = torch.npu.device_count() if torch.npu.is_available() else 0
            assert self.parallel_config.local_world_size <= visible_device_count, (
                f"local_world_size ({self.parallel_config.local_world_size}) must "
                f"be less than or equal to the number of visible devices "
                f"({visible_device_count})."
            )

        # Initialize the distributed environment.
        self._init_worker_distributed_environment()
        # Set random seed.
        set_random_seed(self.model_config.seed)
        # Initialize device properties used by triton kernels.
        init_device_properties_triton()

        # Bind CPUs to this rank if enabled.
        if get_ascend_config().enable_cpu_binding:
            try:
                bind_cpus(self.local_rank)
            except Exception as e:
                logger.warning(f"Failed to bind CPUs on rank {self.local_rank}: {e}. Skipping CPU binding.")
        return device

    def init_device(self):
        # NOTE: Keep `device` a member of `NPUWorker`, as it is checked in the
        # ray scenario. See https://github.com/vllm-project/vllm/pull/26845
        # for more details.
        self.device = self._init_device()
        # Initialize the workspace manager.
        num_ubatches = 1
        init_workspace_manager(self.device, num_ubatches)
        # Init ModelRunner here, so that we have access to self.device.
        if self.use_v2_model_runner:
            logger.warning("NPU model runner v2 is still in development; some features do not work yet.")
            from vllm_ascend.worker.v2.model_runner import NPUModelRunner as NPUModelRunnerV2

            self.model_runner = NPUModelRunnerV2(self.vllm_config, self.device)
        else:
            self.model_runner = NPUModelRunner(self.vllm_config, self.device)

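    # KV-cache sizing: in VNPU mode the budget is derived from the
    # CaMemAllocator pool; otherwise a dummy forward pass is profiled and the
    # non-KV-cache memory is subtracted from the requested budget.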
    @torch.inference_mode()
    def determine_available_memory(self) -> int:
"""Profiles the peak memory usage of the model to determine how much
|
|
|
|
|
memory can be used for KV cache without OOMs.
|
|
|
|
|
|
|
|
|
|
The engine will first conduct a profiling of the existing memory usage.
|
|
|
|
|
Then, it calculates the free memory that can be used for KV cache in
|
|
|
|
|
bytes.
|
|
|
|
|
"""
|
|
|
|
|
GiB = lambda b: b / GiB_bytes
|
2026-04-24 08:31:54 +00:00
|
|
|
if envs_ascend.VLLM_ASCEND_ENABLE_VNPU:
|
|
|
|
|
allocator = CaMemAllocator.get_instance()
|
|
|
|
|
free, total = allocator.get_pool_mem_info()
|
|
|
|
|
if self.cache_config.gpu_memory_utilization <= 0.9:
|
|
|
|
|
logger.warning(
|
|
|
|
|
"GPU memory utilization is set to %.2f. For VNPU mode, it is recommended to set gpu_memory_utilization to a larger value",
|
|
|
|
|
self.cache_config.gpu_memory_utilization,
|
|
|
|
|
)
|
|
|
|
|
available_kv_cache_memory = int(
|
|
|
|
|
total * self.cache_config.gpu_memory_utilization - (total - free)
|
|
|
|
|
)
|
|
|
|
|
available_kv_cache_memory = int(max(available_kv_cache_memory, 0))
|
|
|
|
|
self.available_kv_cache_memory_bytes = available_kv_cache_memory
|
|
|
|
|
logger.info_once(
|
|
|
|
|
"Available KV cache memory: %.2f GiB",
|
|
|
|
|
GiB(self.available_kv_cache_memory_bytes),
|
|
|
|
|
scope="local",
|
|
|
|
|
)
|
|
|
|
|
return int(self.available_kv_cache_memory_bytes)
|
2025-03-20 19:34:44 +08:00
|
|
|
|
|
|
|
|
        # Execute a forward pass with dummy inputs to profile the memory usage
        # of the model.
        with memory_profiling(
            self.init_snapshot,
            weights_memory=int(self.model_runner.model_memory_usage),
        ) as profile_result:
            self.model_runner.profile_run()

        free_gpu_memory = profile_result.after_profile.free_memory
        assert self.init_snapshot.free_memory > free_gpu_memory, (
            "Error in memory profiling. "
f"Initial free memory {GiB(self.init_snapshot.free_memory)} GiB, "
|
|
|
|
|
f"current free memory {GiB(free_gpu_memory)} GiB. "
|
|
|
|
|
"This happens when other processes sharing the same container "
|
|
|
|
|
"release GPU memory while vLLM is profiling during initialization. "
|
|
|
|
|
"To fix this, ensure consistent GPU memory allocation or "
|
|
|
|
|
"isolate vLLM in its own container."
|
|
|
|
|
)
|
2026-03-23 14:22:59 +08:00
|
|
|
self.available_kv_cache_memory_bytes = self.requested_memory - profile_result.non_kv_cache_memory
|
        logger.debug(profile_result)
        logger.info_once(
            "Available KV cache memory: %.2f GiB", GiB(self.available_kv_cache_memory_bytes), scope="local"
)
|
2026-03-23 14:22:59 +08:00
|
|
|
|
2026-02-25 14:28:08 +08:00
|
|
|
return int(self.available_kv_cache_memory_bytes)
|
2025-03-20 19:34:44 +08:00
|
|
|
|
|
|
|
|
def execute_model(
|
|
|
|
|
self,
|
|
|
|
|
scheduler_output: "SchedulerOutput",
|
2026-01-06 08:44:29 +08:00
|
|
|
) -> ModelRunnerOutput | AsyncModelRunnerOutput | None:
|
2025-09-25 14:15:02 +08:00
|
|
|
# enable msMonitor to monitor the performance of vllm-ascend
|
|
|
|
|
if envs_ascend.MSMONITOR_USE_DAEMON:
|
|
|
|
|
dp.step()
|
|
|
|
|
|
2026-03-15 09:45:09 +08:00
|
|
|
if self._pp_send_work:
|
|
|
|
|
for handle in self._pp_send_work:
|
|
|
|
|
handle.wait()
|
|
|
|
|
self._pp_send_work = []
|
|
|
|
|
|
2025-07-11 15:30:51 +08:00
|
|
|
intermediate_tensors = None
|
2025-09-19 11:29:50 +08:00
|
|
|
forward_pass = scheduler_output.total_num_scheduled_tokens > 0
|
|
|
|
|
if forward_pass and not get_pp_group().is_first_rank:
|
2026-02-06 15:35:06 +08:00
|
|
|
# If flashcomm1 is used, this all_gather_group parameter needs to be removed, otherwise
|
|
|
|
|
# it will conflict with the all-gather operation in flashcomm1.
|
2026-03-01 20:22:50 +08:00
|
|
|
if enable_sp():
|
mooncake connector support pipeline parallel & fix pp with flashcomm1 (#4054)
### What this PR does / why we need it?
To support pipeline parallelism with PD disaggregation, this PR adds PP
support to the mooncake connector and fixes bugs that show up when PP is
enabled together with other optimizations, including the following changes:
- the mooncake connector now supports PP on the prefill side; decode-side
PP is not supported yet
- fix bugs when PP and flashcomm1 are enabled together
- optimize the ascend-scheduler to run a full batch in every pipeline
stage; previously the batch sizes of all pipeline stages summed to
max_num_seq, leaving the pipeline under-filled, whereas now every stage
can run with the full batch_size = max_num_seq (the same change will be
contributed to the vLLM scheduler too; see the sketch below)
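A rough illustration of that batching difference, with toy numbers and
hypothetical variable names (this is not the scheduler code):
```python
# Sketch only: how the max_num_seqs budget is spread across pipeline stages.
max_num_seqs = 16
pp_size = 4

# Before: the in-flight batches of all pipeline stages summed to
# max_num_seqs, so each stage ran with only a fraction of the budget.
per_stage_before = max_num_seqs // pp_size  # 4 sequences per stage

# After: every pipeline stage can be filled up to the full budget.
per_stage_after = max_num_seqs  # 16 sequences per stage

print(per_stage_before, per_stage_after)
```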
### Does this PR introduce _any_ user-facing change?
add `pp_size` in mooncake connector kv_connector_extra_config
```
"kv_connector_extra_config": {
"use_ascend_direct": true,
"prefill": {
"dp_size": 1,
"tp_size": 4,
"pp_size": 4
},
"decode": {
"dp_size": 16,
"tp_size": 1
}
}
```
### How was this patch tested?
- vLLM version: v0.12.0
- vLLM main:
https://github.com/vllm-project/vllm/commit/ad32e3e19ccf0526cb6744a5fed09a138a5fb2f9
---------
Signed-off-by: chenxiao <Jaychou1620@Gmail.com>
Signed-off-by: Kurumi5210 <Jaychou1620@Gmail.com>
Signed-off-by: Kurumi5210 <jaychou1620@gmail.com>
Signed-off-by: 秋刀鱼 <jaychou1620@Gmail.com>
Co-authored-by: chenxiao <Jaychou1620@Gmail.com>
Co-authored-by: zss <zss@qq.com>
Co-authored-by: zss <3265779424@qq.com>
2025-12-10 16:01:43 +08:00
|
|
|
all_gather_group = None
|
|
|
|
|
else:
|
|
|
|
|
all_gather_group = get_tp_group()
|
2026-03-15 09:45:09 +08:00
|
|
|
tensor_dict, comm_handles, comm_postprocess = get_pp_group().irecv_tensor_dict(
|
|
|
|
|
all_gather_group=all_gather_group
|
|
|
|
|
)
|
|
|
|
|
assert tensor_dict is not None
|
|
|
|
|
intermediate_tensors = AsyncIntermediateTensors(
|
|
|
|
|
tensor_dict,
|
|
|
|
|
comm_handles=comm_handles,
|
|
|
|
|
comm_postprocess=comm_postprocess,
|
2026-02-06 15:35:06 +08:00
|
|
|
)
|
2025-07-11 15:30:51 +08:00
|
|
|
|
2026-02-06 15:35:06 +08:00
|
|
|
output = self.model_runner.execute_model(scheduler_output, intermediate_tensors)
|
|
|
|
|
if isinstance(output, (ModelRunnerOutput, AsyncModelRunnerOutput, NoneType)):
|
2025-09-19 11:29:50 +08:00
|
|
|
return output
|
2025-08-04 10:08:58 +08:00
|
|
|
|
2025-09-19 11:29:50 +08:00
|
|
|
assert isinstance(output, IntermediateTensors)
|
|
|
|
|
parallel_config = self.vllm_config.parallel_config
|
2026-02-06 15:35:06 +08:00
|
|
|
assert parallel_config.distributed_executor_backend != "external_launcher" and not get_pp_group().is_last_rank
|
|
|
|
|
# If flashcomm1 is used, this all_gather_group parameter needs to be removed, otherwise
|
|
|
|
|
# it will conflict with the all-gather operation in flashcomm1.
|
2026-03-01 20:22:50 +08:00
|
|
|
if enable_sp():
|
2025-12-10 16:01:43 +08:00
|
|
|
all_gather_group = None
|
|
|
|
|
else:
|
|
|
|
|
all_gather_group = get_tp_group()
|
2026-03-15 09:45:09 +08:00
|
|
|
self._pp_send_work = get_pp_group().isend_tensor_dict(
|
|
|
|
|
output.tensors,
|
|
|
|
|
all_gather_group=all_gather_group,
|
|
|
|
|
)
|
2025-08-04 21:37:50 +08:00
|
|
|
|
2025-09-19 11:29:50 +08:00
|
|
|
kv_connector_output = output.kv_connector_output
|
|
|
|
|
if not kv_connector_output:
|
|
|
|
|
return None
|
2025-08-04 10:08:58 +08:00
|
|
|
|
2025-09-19 11:29:50 +08:00
|
|
|
# In case of PP with kv transfer, we need to pass through the
|
|
|
|
|
# kv_connector_output
|
2026-02-06 15:35:06 +08:00
|
|
|
if not kv_connector_output.finished_sending and not kv_connector_output.finished_recving:
|
2025-09-19 11:29:50 +08:00
|
|
|
return EMPTY_MODEL_RUNNER_OUTPUT
|
|
|
|
|
output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
|
|
|
|
|
output.kv_connector_output = kv_connector_output
|
2025-08-04 10:08:58 +08:00
|
|
|
return output
|
2025-03-20 19:34:44 +08:00
|
|
|
|
2025-11-26 11:48:58 +08:00
|
|
|
@torch.inference_mode()
|
2026-02-06 15:35:06 +08:00
|
|
|
def sample_tokens(self, grammar_output: "GrammarOutput") -> ModelRunnerOutput | AsyncModelRunnerOutput:
|
2025-11-26 11:48:58 +08:00
|
|
|
return self.model_runner.sample_tokens(grammar_output)
|
|
|
|
|
|
2025-03-20 19:34:44 +08:00
|
|
|
def load_model(self) -> None:
|
2025-06-06 21:54:02 +08:00
|
|
|
if self.vllm_config.model_config.enable_sleep_mode:
|
|
|
|
|
allocator = CaMemAllocator.get_instance()
|
2026-02-06 15:35:06 +08:00
|
|
|
assert allocator.get_current_usage() == 0, "Sleep mode can only be used for one instance per process."
|
2025-06-06 21:54:02 +08:00
|
|
|
context = allocator.use_memory_pool(tag="weights")
|
2026-04-21 03:05:32 +00:00
|
|
|
elif envs_ascend.VLLM_ASCEND_ENABLE_VNPU:
|
|
|
|
|
allocator = CaMemAllocator.get_instance()
|
|
|
|
|
assert (
|
|
|
|
|
allocator.get_current_usage() == 0
|
|
|
|
|
), "vNPU mode can only be used for one instance per process."
|
|
|
|
|
context = allocator.use_memory_pool(tag="weights")
|
2025-06-06 21:54:02 +08:00
|
|
|
else:
|
|
|
|
|
from contextlib import nullcontext
|
2026-02-06 15:35:06 +08:00
|
|
|
|
2025-06-06 21:54:02 +08:00
|
|
|
context = nullcontext() # type: ignore
|
2025-12-29 22:48:05 +08:00
|
|
|
|
|
|
|
|
with context, set_current_vllm_config(self.vllm_config):
|
2025-06-06 21:54:02 +08:00
|
|
|
self.model_runner.load_model()
|
2026-04-21 03:05:32 +00:00
|
|
|
if envs_ascend.VLLM_ASCEND_ENABLE_VNPU:
|
|
|
|
|
# save memory to host with lock
|
|
|
|
|
self.offload_vram()
|
|
|
|
|
succ, _ = self.try_reload_vram()
|
|
|
|
|
assert succ, "Failed to reload model weights after offloading."
|
|
|
|
|
|
|
|
|
|
def offload_vram(self) -> None:
|
|
|
|
|
allocator = CaMemAllocator.get_instance()
|
|
|
|
|
allocator.offload_vram(offload_tags=("weights",))
|
|
|
|
|
|
|
|
|
|
def try_reload_vram(self) -> tuple[bool, bool]:
|
|
|
|
|
allocator = CaMemAllocator.get_instance()
|
|
|
|
|
return allocator.try_reload_vram(tags=None)
|
|
|
|
|
|
|
|
|
|
def vnpu_unlock_gpu(self) -> None:
|
|
|
|
|
allocator = CaMemAllocator.get_instance()
|
|
|
|
|
allocator.vnpu_unlock_gpu()
|
2025-03-20 19:34:44 +08:00
|
|
|
|
2026-03-06 09:08:52 +08:00
|
|
|
def compile_or_warm_up_model(self) -> float:
|
2025-08-20 09:01:04 +08:00
|
|
|
# Note: need to adapt for graph mode.
|
2026-02-06 15:35:06 +08:00
|
|
|
warmup_sizes = (self.vllm_config.compilation_config.compile_sizes or []).copy()
|
2026-01-06 21:55:47 +08:00
|
|
|
if not self.model_config.enforce_eager:
|
2026-01-07 18:42:55 +08:00
|
|
|
cg_capture_sizes: list[int] = []
|
|
|
|
|
if self.vllm_config.compilation_config.cudagraph_mode != CUDAGraphMode.NONE:
|
|
|
|
|
cg_sizes = self.vllm_config.compilation_config.cudagraph_capture_sizes
|
|
|
|
|
cg_capture_sizes = [] if cg_sizes is None else cg_sizes
|
2026-02-06 15:35:06 +08:00
|
|
|
warmup_sizes = [x for x in warmup_sizes if x not in cg_capture_sizes]
|
2026-01-07 18:42:55 +08:00
|
|
|
|
2026-02-06 15:35:06 +08:00
|
|
|
compile_ranges = self.vllm_config.compilation_config.get_compile_ranges()
|
2026-01-07 18:42:55 +08:00
|
|
|
# For each compile_range, if none of the batch sizes
|
|
|
|
|
# in warmup_sizes or cudagraph_capture_sizes are in the range,
|
|
|
|
|
# add the end of the range to ensure compilation/warmup.
|
|
|
|
|
all_sizes = set(cg_capture_sizes)
|
|
|
|
|
all_sizes.update([x for x in warmup_sizes if isinstance(x, int)])
|
|
|
|
|
for compile_range in compile_ranges:
|
|
|
|
|
if not any(x in compile_range for x in all_sizes):
|
|
|
|
|
warmup_sizes.append(compile_range.end)
|
|
|
|
|
|
2025-04-23 20:56:24 +08:00
|
|
|
for size in sorted(warmup_sizes, reverse=True):
|
|
|
|
|
logger.info("Compile and warming up model for size %d", size)
|
|
|
|
|
self.model_runner._dummy_run(size)
|
|
|
|
|
if not self.model_config.enforce_eager:
|
|
|
|
|
self.model_runner.capture_model()
|
2025-09-10 14:06:38 +08:00
|
|
|
# Call ATB matmul to warm up; otherwise, the first operation (ReshapeAndCache)
|
|
|
|
|
# may cause performance degradation at runtime.
|
2025-12-17 14:08:19 +08:00
|
|
|
if get_ascend_device_type() != AscendDeviceType.A5:
|
2025-12-12 15:50:57 +08:00
|
|
|
self._warm_up_atb()
|
2025-03-20 19:34:44 +08:00
|
|
|
# Reset the seed to ensure that the random state is not affected by
|
|
|
|
|
# the model initialization and profiling.
|
2026-01-07 09:25:55 +08:00
|
|
|
set_random_seed(self.model_config.seed)
|
2026-03-06 09:08:52 +08:00
|
|
|
return self.vllm_config.compilation_config.compilation_time
|
2025-03-20 19:34:44 +08:00
|
|
|
|
2025-09-10 14:06:38 +08:00
|
|
|
def _warm_up_atb(self):
|
|
|
|
|
x = torch.rand((2, 4), dtype=torch.float16).npu()
|
|
|
|
|
weight = torch.rand((2, 4), dtype=torch.float16).npu()
|
|
|
|
|
c = torch.rand((4, 4), dtype=torch.float32).npu()
|
|
|
|
|
torch_npu._npu_matmul_add_fp32(x, weight, c)
|
|
|
|
|
|
2025-03-20 19:34:44 +08:00
|
|
|
def get_model(self) -> nn.Module:
|
|
|
|
|
return self.model_runner.get_model()
|
|
|
|
|
|
2026-02-06 15:35:06 +08:00
|
|
|
def get_kv_connector_handshake_metadata(self) -> dict | None:
|
2025-12-17 09:28:03 +08:00
|
|
|
"""Get KV connector metadata from this worker if available."""
|
|
|
|
|
if not has_kv_transfer_group():
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
connector = get_kv_transfer_group()
|
|
|
|
|
|
|
|
|
|
# Return None for connectors that don't need to exchange handshake
|
|
|
|
|
# metadata across workers.
|
|
|
|
|
if (metadata := connector.get_handshake_metadata()) is None:
|
|
|
|
|
return None
|
|
|
|
|
return {self.rank: metadata}
|
2025-11-29 16:09:45 +08:00
|
|
|
|
2025-03-28 19:34:23 +08:00
|
|
|
def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
|
2025-03-20 19:34:44 +08:00
|
|
|
return self.model_runner.get_kv_cache_spec()
|
|
|
|
|
|
2026-01-26 09:03:33 +08:00
|
|
|
def update_max_model_len(self, max_model_len: int) -> None:
|
|
|
|
|
"""Update max_model_len after auto-fit to NPU memory.
|
|
|
|
|
|
|
|
|
|
This is called when max_model_len=-1 is used and the engine
|
|
|
|
|
automatically determines the maximum context length that fits
|
|
|
|
|
in NPU memory. Workers need to update their cached max_model_len
|
|
|
|
|
to match the engine's decision.
|
|
|
|
|
"""
|
|
|
|
|
self.model_config.max_model_len = max_model_len
|
|
|
|
|
if self.model_runner is not None:
|
|
|
|
|
self.model_runner.update_max_model_len(max_model_len)
|
|
|
|
|
logger.debug("Updated max_model_len to %d", max_model_len)
|
|
|
|
|
|
2025-03-20 19:34:44 +08:00
|
|
|
def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None:
|
|
|
|
|
"""Allocate NPU KV cache with the specified kv_cache_config."""
|
2026-03-09 10:49:04 +08:00
|
|
|
ensure_kv_transfer_initialized(self.vllm_config, kv_cache_config)
|
2025-06-06 21:54:02 +08:00
|
|
|
if self.vllm_config.model_config.enable_sleep_mode:
|
|
|
|
|
allocator = CaMemAllocator.get_instance()
|
|
|
|
|
context = allocator.use_memory_pool(tag="kv_cache")
|
2026-04-21 03:05:32 +00:00
|
|
|
elif envs_ascend.VLLM_ASCEND_ENABLE_VNPU:
|
|
|
|
|
allocator = CaMemAllocator.get_instance()
|
|
|
|
|
context = allocator.use_memory_pool(tag="kv_cache")
|
2025-06-06 21:54:02 +08:00
|
|
|
else:
|
|
|
|
|
from contextlib import nullcontext
|
2026-02-06 15:35:06 +08:00
|
|
|
|
2025-06-06 21:54:02 +08:00
|
|
|
context = nullcontext() # type: ignore
|
|
|
|
|
with context:
|
|
|
|
|
self.model_runner.initialize_kv_cache(kv_cache_config)
|
2025-03-20 19:34:44 +08:00
|
|
|
|
2026-03-05 16:18:34 +08:00
|
|
|
def profile(self, is_start: bool = True, profile_prefix: str | None = None):
|
|
|
|
|
# Check if profiling is enabled (RFC #6954 - align with upstream vLLM)
|
|
|
|
|
if self.profiler_config is None or self.profiler_config.profiler is None:
|
|
|
|
|
raise RuntimeError(
|
|
|
|
|
"Profiling is not enabled. Please set --profiler-config to enable "
|
|
|
|
|
"profiling. Example: "
|
|
|
|
|
"'--profiler-config.profiler=torch --profiler-config.torch_profiler_dir"
|
|
|
|
|
"=YOUR_DIR_PATH_TO_DUMP_TRACE'"
|
|
|
|
|
)
|
|
|
|
|
|
2025-03-20 19:34:44 +08:00
|
|
|
if is_start:
|
2026-03-05 16:18:34 +08:00
|
|
|
from vllm.distributed.utils import get_worker_rank_suffix
|
|
|
|
|
|
|
|
|
|
rank_suffix = get_worker_rank_suffix(global_rank=self.rank)
|
|
|
|
|
trace_name = f"{profile_prefix}_{rank_suffix}" if profile_prefix else rank_suffix
|
|
|
|
|
|
|
|
|
|
if self.profiler is None:
|
|
|
|
|
self.profiler = self._create_profiler(trace_name)
|
|
|
|
|
logger.debug("Starting torch profiler with trace name: %s", trace_name)
|
|
|
|
|
self.profiler.start() # type: ignore[attr-defined]
|
|
|
|
|
else:
|
|
|
|
|
# Profiler already initialized. Restart profiling but keep
|
|
|
|
|
# the original trace name from the first initialization.
|
|
|
|
|
self.profiler.start()
|
2025-03-20 19:34:44 +08:00
|
|
|
else:
|
2026-03-05 16:18:34 +08:00
|
|
|
if self.profiler is None:
|
|
|
|
|
logger.warning("Profiler was not started, nothing to stop.")
|
|
|
|
|
return
|
2025-03-20 19:34:44 +08:00
|
|
|
self.profiler.stop()
|
|
|
|
|
|
2025-05-22 19:20:51 +08:00
|
|
|
def add_lora(self, lora_request: LoRARequest) -> bool:
|
|
|
|
|
return self.model_runner.add_lora(lora_request)
|
|
|
|
|
|
|
|
|
|
def remove_lora(self, lora_id: int) -> bool:
|
|
|
|
|
return self.model_runner.remove_lora(lora_id)
|
|
|
|
|
|
|
|
|
|
def list_loras(self) -> set[int]:
|
|
|
|
|
return self.model_runner.list_loras()
|
|
|
|
|
|
|
|
|
|
def pin_lora(self, lora_id: int) -> bool:
|
|
|
|
|
return self.model_runner.pin_lora(lora_id)
|
|
|
|
|
|
2026-02-27 16:05:21 +08:00
|
|
|
def reset_encoder_cache(self) -> None:
|
|
|
|
|
self.model_runner.reset_encoder_cache()
|
|
|
|
|
|
2025-07-21 11:50:46 +08:00
|
|
|
def execute_dummy_batch(self) -> None:
|
2026-02-06 15:35:06 +08:00
|
|
|
self.model_runner._dummy_run(num_tokens=self.model_runner.decode_token_per_req, uniform_decode=True)
|
2025-05-12 17:31:29 +08:00
|
|
|
|
2025-04-15 10:24:02 +08:00
|
|
|
def _init_worker_distributed_environment(self) -> None:
|
|
|
|
|
"""Initialize the distributed environment."""
|
2026-01-07 09:11:26 +08:00
|
|
|
init_batch_invariance()
|
2026-02-06 15:35:06 +08:00
|
|
|
init_distributed_environment(
|
|
|
|
|
self.parallel_config.world_size, self.rank, self.distributed_init_method, self.local_rank, "hccl"
|
|
|
|
|
)
|
2025-12-05 10:31:49 +08:00
|
|
|
ensure_model_parallel_initialized(
|
|
|
|
|
self.parallel_config.tensor_parallel_size,
|
|
|
|
|
self.parallel_config.pipeline_parallel_size,
|
|
|
|
|
self.parallel_config.prefill_context_parallel_size,
|
2026-02-06 15:35:06 +08:00
|
|
|
self.parallel_config.decode_context_parallel_size,
|
|
|
|
|
)
|
2025-07-28 14:06:20 +08:00
|
|
|
init_ascend_model_parallel(self.parallel_config)
|
2025-12-03 20:48:45 +08:00
|
|
|
ensure_ec_transfer_initialized(self.vllm_config)
|
2025-04-15 10:24:02 +08:00
|
|
|
|
2026-03-05 16:18:34 +08:00
|
|
|
def _create_profiler(self, trace_name: str):
|
|
|
|
|
"""Create torch_npu profiler with trace naming for unique files per worker (RFC #6954)."""
|
|
|
|
|
profiler_config = self.profiler_config
|
2025-03-20 19:34:44 +08:00
|
|
|
|
2026-03-05 16:18:34 +08:00
|
|
|
if profiler_config.profiler != "torch":
|
|
|
|
|
raise RuntimeError(f"Unrecognized profiler: {profiler_config.profiler}")
|
|
|
|
|
if not profiler_config.torch_profiler_dir:
|
|
|
|
|
raise RuntimeError("torch_profiler_dir cannot be empty.")
|
|
|
|
|
if envs_ascend.MSMONITOR_USE_DAEMON:
|
|
|
|
|
raise RuntimeError("MSMONITOR_USE_DAEMON and torch profiler cannot be both enabled at the same time.")
|
|
|
|
|
|
|
|
|
|
experimental_config = torch_npu.profiler._ExperimentalConfig(
|
|
|
|
|
export_type=torch_npu.profiler.ExportType.Text,
|
|
|
|
|
profiler_level=torch_npu.profiler.ProfilerLevel.Level1,
|
|
|
|
|
msprof_tx=False,
|
2026-03-27 16:37:54 +08:00
|
|
|
aic_metrics=torch_npu.profiler.AiCMetrics.PipeUtilization,
|
2026-03-05 16:18:34 +08:00
|
|
|
l2_cache=False,
|
|
|
|
|
op_attr=False,
|
|
|
|
|
data_simplification=True,
|
|
|
|
|
record_op_args=False,
|
|
|
|
|
gc_detect_threshold=None,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
return torch_npu.profiler.profile(
|
|
|
|
|
activities=[
|
|
|
|
|
torch_npu.profiler.ProfilerActivity.CPU,
|
|
|
|
|
torch_npu.profiler.ProfilerActivity.NPU,
|
|
|
|
|
],
|
|
|
|
|
with_stack=False,
|
|
|
|
|
profile_memory=profiler_config.torch_profiler_with_memory,
|
|
|
|
|
# NOTE: torch_npu.profiler.with_modules is equivalent to torch.profiler.with_stack.
|
|
|
|
|
# The with_stack option in torch_npu.profiler introduces significant time overhead.
|
|
|
|
|
with_modules=profiler_config.torch_profiler_with_stack,
|
|
|
|
|
experimental_config=experimental_config,
|
|
|
|
|
on_trace_ready=torch_npu.profiler.tensorboard_trace_handler(
|
|
|
|
|
profiler_config.torch_profiler_dir,
|
|
|
|
|
worker_name=trace_name,
|
|
|
|
|
),
|
|
|
|
|
)
|
2025-07-20 02:11:57 +08:00
|
|
|
|
|
|
|
|
def get_supported_pooling_tasks(self):
|
|
|
|
|
return self.model_runner.get_supported_pooling_tasks()
|
2025-07-26 08:20:21 +08:00
|
|
|
|
|
|
|
|
def get_supported_tasks(self) -> "tuple[SupportedTask, ...]":
|
|
|
|
|
return self.model_runner.get_supported_tasks()
|
2025-08-22 17:09:08 +08:00
|
|
|
|
2026-02-06 15:35:06 +08:00
|
|
|
def take_draft_token_ids(self) -> DraftTokenIds | None:
|
2025-09-03 10:58:08 +08:00
|
|
|
return self.model_runner.take_draft_token_ids()
|
2026-02-11 15:24:48 +08:00
|
|
|
|
|
|
|
|
def check_health(self) -> None:
|
|
|
|
|
import subprocess
|
|
|
|
|
|
|
|
|
|
logger.info("check_health Start!")
|
|
|
|
|
try:
|
|
|
|
|
result = subprocess.run(
|
|
|
|
|
["npu-smi", "info", "-i", str(self.local_rank), "-t", "health"],
|
|
|
|
|
capture_output=True,
|
|
|
|
|
text=True,
|
|
|
|
|
timeout=10,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
if result.returncode == 0:
|
|
|
|
|
parse_text_output(result.stdout)
|
|
|
|
|
logger.info("check_health success!")
|
|
|
|
|
else:
|
|
|
|
|
logger.info(f"query NPU card {self.local_rank} fail: {result.stderr}")
|
|
|
|
|
except subprocess.TimeoutExpired:
|
|
|
|
|
logger.info(f"query NPU card {self.local_rank} timeout.")
|
|
|
|
|
except FileNotFoundError:
|
|
|
|
|
logger.info("npu-smi tool not found.")
|
|
|
|
|
except Exception as e:
|
|
|
|
|
logger.info(f"query NPU card {self.local_rank} fail: {e}")
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def parse_text_output(output) -> None:
|
|
|
|
|
lines = output.strip().split("\n")
|
|
|
|
|
for line in lines:
|
|
|
|
|
line = line.strip()
|
|
|
|
|
if "Health" in line:
|
|
|
|
|
if line.split(":")[-1].strip() != "OK":
|
|
|
|
|
raise RuntimeError("NPU card health status is not OK")
|
|
|
|
|
return
|