2025-03-20 19:34:44 +08:00
|
|
|
#
|
|
|
|
|
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
|
|
|
|
# Copyright 2023 The vLLM team.
|
|
|
|
|
#
|
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
|
#
|
|
|
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
|
#
|
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
|
# limitations under the License.
|
2025-04-17 14:59:56 +08:00
|
|
|
# This file is a part of the vllm-ascend project.
|
|
|
|
|
# Adapted from vllm-project/vllm/vllm/worker/gpu_worker.py
|
2025-03-20 19:34:44 +08:00
|
|
|
#
|
|
|
|
|
|
2025-06-06 21:54:02 +08:00
|
|
|
from typing import Optional
|
2025-03-20 19:34:44 +08:00
|
|
|
|
|
|
|
|
import torch
|
|
|
|
|
import torch.nn as nn
|
|
|
|
|
import torch_npu
|
support aclgraph (#426)
<!-- Thanks for sending a pull request!
BEFORE SUBMITTING, PLEASE READ
https://docs.vllm.ai/en/latest/contributing/overview.html
-->
### What this PR does / why we need it?
<!--
- Please clarify what changes you are proposing. The purpose of this
section is to outline the changes and how this PR fixes the issue.
If possible, please consider writing useful notes for better and faster
reviews in your PR.
- Please clarify why the changes are needed. For instance, the use case
and bug description.
- Fixes #
-->
This PR enables vllm-ascend to use the piecewise_graph feature
provided by the v1 engine:
1. Register unified_ascend_attention_with_output so that piecewise_graph
can split the graph.
2. Support NPUGraph to accelerate kernel launch.
### Does this PR introduce _any_ user-facing change?
<!--
Note that it means *any* user-facing change including all aspects such
as API, interface or other behavior changes.
Documentation-only updates are not considered user-facing changes.
-->
NPUGraph is enabled by default; users can disable it by
setting enforce_eager.
This requires versions of torch_npu and CANN
that support graph capture.
### How was this patch tested?
<!--
CI passed with new added/existing test.
If it was tested in a way different from regular unit tests, please
clarify how you tested step by step, ideally copy and paste-able, so
that other reviewers can test and check, and descendants can verify in
the future.
If tests were not added, please describe why they were not added and/or
why it was difficult to add.
-->
It is turned on by default.
---------
Signed-off-by: Bug Hunter Yan <yanpq@zju.edu.cn>
Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
Co-authored-by: Yizhou Liu <liu_yizhou@outlook.com>
2025-04-23 20:56:24 +08:00
|
|
|
from torch_npu.op_plugin.atb._atb_ops import _register_atb_extensions
|
2025-03-20 19:34:44 +08:00
|
|
|
from vllm import envs
|
2025-04-15 10:24:02 +08:00
|
|
|
from vllm.config import VllmConfig
|
2025-04-18 12:23:32 +08:00
|
|
|
from vllm.distributed import (ensure_model_parallel_initialized,
|
2025-06-25 16:20:14 +08:00
|
|
|
init_distributed_environment)
|
2025-04-30 09:15:50 +08:00
|
|
|
from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
|
2025-07-11 15:30:51 +08:00
|
|
|
from vllm.distributed.parallel_state import get_pp_group, get_tp_group
|
2025-04-15 10:18:05 +08:00
|
|
|
from vllm.logger import logger
|
2025-05-22 19:20:51 +08:00
|
|
|
from vllm.lora.request import LoRARequest
|
2025-07-11 15:30:51 +08:00
|
|
|
from vllm.sequence import IntermediateTensors
|
2025-06-06 21:54:02 +08:00
|
|
|
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, GiB_bytes
|
2025-03-21 15:43:43 +08:00
|
|
|
from vllm.v1.core.sched.output import SchedulerOutput
|
2025-06-06 21:54:02 +08:00
|
|
|
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
|
2025-03-20 19:34:44 +08:00
|
|
|
from vllm.v1.outputs import ModelRunnerOutput
|
|
|
|
|
from vllm.v1.worker.worker_base import WorkerBase
|
|
|
|
|
|
2025-07-07 22:37:14 +08:00
|
|
|
import vllm_ascend.envs as envs_ascend
|
|
|
|
|
from vllm_ascend.ascend_config import get_ascend_config, init_ascend_config
|
2025-06-06 21:54:02 +08:00
|
|
|
from vllm_ascend.device_allocator.camem import CaMemAllocator
|
2025-04-19 17:38:18 +08:00
|
|
|
from vllm_ascend.distributed.parallel_state import init_ascend_model_parallel
|
2025-04-15 10:24:02 +08:00
|
|
|
from vllm_ascend.platform import NPUPlatform
|
2025-07-07 22:37:14 +08:00
|
|
|
from vllm_ascend.utils import (check_kv_cache_bytes_cache_exist,
|
|
|
|
|
check_torchair_cache_exist,
|
|
|
|
|
delete_torchair_cache_file,
|
|
|
|
|
read_kv_cache_bytes_from_file,
|
|
|
|
|
sleep_mode_enabled, try_register_lib)
|
2025-03-20 19:34:44 +08:00
|
|
|
from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class NPUWorker(WorkerBase):
    """vLLM v1 worker that drives one Ascend NPU.

    Responsibilities handled here: selecting the NPU device, initializing the
    (HCCL-backed) distributed environment, profiling memory to size the KV
    cache, optional sleep mode through ``CaMemAllocator``, profiling, and
    delegating model execution / LoRA management to ``NPUModelRunner``.
    """

    def __init__(
            self,
            vllm_config: VllmConfig,
            local_rank: int,
            rank: int,
            distributed_init_method: str,
            is_driver_worker: bool = False,
            # Additional parameters for compatibility with vllm
            **kwargs):
        """Initialize the worker for Ascend.

        Args:
            vllm_config: Full engine configuration; also used to initialize
                the Ascend-specific config before the base class runs.
            local_rank: Device index on this node; selects ``npu:<local_rank>``
                in :meth:`init_device`.
            rank: Global rank of this worker.
            distributed_init_method: Init URL passed to
                ``init_distributed_environment``.
            is_driver_worker: Only the driver worker returns model output
                from :meth:`execute_model`.
            **kwargs: Accepted and ignored, for signature compatibility with
                vLLM's worker constructor.
        """
        # register patch for vllm
        from vllm_ascend.utils import adapt_patch
        adapt_patch()
        # Register ops when worker init.
        from vllm_ascend import ops
        ops.register_dummy_fusion_op()
        _register_atb_extensions()
        # init ascend config (must happen before anything reads it)
        init_ascend_config(vllm_config)

        super().__init__(vllm_config=vllm_config,
                         local_rank=local_rank,
                         rank=rank,
                         distributed_init_method=distributed_init_method,
                         is_driver_worker=is_driver_worker)

        # Try to import mindie_turbo to accelerate vLLM inference.
        try_register_lib(
            "mindie_turbo",
            "MindIE Turbo is installed. vLLM inference will be accelerated with MindIE Turbo."
        )

        # "auto" means: store the KV cache in the model's own dtype.
        if self.cache_config.cache_dtype == "auto":
            self.cache_dtype = self.model_config.dtype
        else:
            self.cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
                self.cache_config.cache_dtype]

        if self.model_config.trust_remote_code:
            # note: lazy import to avoid importing torch before initializing
            from vllm.utils import init_cached_hf_modules
            init_cached_hf_modules()

        # None unless VLLM_TORCH_PROFILER_DIR is set (see _init_profiler).
        self.profiler = self._init_profiler()

    @staticmethod
    def _check_sleep_mode_enabled() -> None:
        """Raise ValueError unless vllm-ascend was built with sleep mode."""
        if not sleep_mode_enabled():
            raise ValueError(
                "Sleep mode is not enabled. Please compile vllm-ascend with COMPILE_CUSTOM_KERNELS=1."
            )

    def _memory_pool_context(self, tag: str):
        """Return the allocation context to use for memory tagged ``tag``.

        With sleep mode enabled, allocations made inside the returned context
        go to the ``CaMemAllocator`` pool named ``tag`` (so they can later be
        released by :meth:`sleep`); otherwise a no-op context is returned.
        """
        if self.vllm_config.model_config.enable_sleep_mode:
            return CaMemAllocator.get_instance().use_memory_pool(tag=tag)
        from contextlib import nullcontext
        return nullcontext()

    def sleep(self, level: int = 1) -> None:
        """Release NPU memory held by the sleep-mode allocator.

        Args:
            level: With ``1``, memory tagged ``"weights"`` is offloaded rather
                than discarded; with any other level nothing is offloaded.

        Raises:
            ValueError: If vllm-ascend was built without sleep-mode support.
        """
        self._check_sleep_mode_enabled()
        free_bytes_before_sleep = NPUPlatform.mem_get_info()[0]
        allocator = CaMemAllocator.get_instance()
        allocator.sleep(offload_tags=("weights", ) if level == 1 else tuple())
        free_bytes_after_sleep, total = NPUPlatform.mem_get_info()
        freed_bytes = free_bytes_after_sleep - free_bytes_before_sleep
        used_bytes = total - free_bytes_after_sleep
        assert freed_bytes >= 0, "Memory usage increased after sleeping."
        logger.info(
            "Sleep mode freed %.2f GiB memory, "
            "%.2f GiB memory is still in use.", freed_bytes / GiB_bytes,
            used_bytes / GiB_bytes)

    def wake_up(self, tags: Optional[list[str]] = None) -> None:
        """Restore memory released by :meth:`sleep`.

        Args:
            tags: Memory-pool tags to restore; forwarded unchanged to
                ``CaMemAllocator.wake_up``.

        Raises:
            ValueError: If vllm-ascend was built without sleep-mode support.
        """
        self._check_sleep_mode_enabled()
        allocator = CaMemAllocator.get_instance()
        allocator.wake_up(tags=tags)

    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        """Record the number of device (NPU) and CPU KV-cache blocks."""
        self.cache_config.num_gpu_blocks = num_gpu_blocks
        self.cache_config.num_cpu_blocks = num_cpu_blocks

    def init_device(self):
        """Bind this worker to its NPU, set up distributed state and runner."""
        device = torch.device(f"npu:{self.local_rank}")
        NPUPlatform.set_device(device)
        NPUPlatform.empty_cache()
        # Baseline free memory; determine_available_memory() checks that
        # profiling consumed memory relative to this value.
        self.init_npu_memory = NPUPlatform.mem_get_info()[0]

        # Initialize the distributed environment.
        self._init_worker_distributed_environment()
        # Set random seed.
        NPUPlatform.seed_everything(self.model_config.seed)

        # Init ModelRunner here, so that we have access to self.device.
        self.model_runner = NPUModelRunner(self.vllm_config, device)

    def determine_available_memory(self) -> int:
        """Profile the model and return the bytes available for the KV cache.

        Runs a dummy forward pass, takes the peak torch allocation plus any
        device memory not tracked by torch's allocator, and returns
        ``total * gpu_memory_utilization - peak`` (never negative). When the
        torchair graph mode is enabled, a previously cached size may be reused,
        and a configurable tolerance is subtracted from a freshly computed one.
        The chosen value is also stored on ``model_runner.new_kv_cache_bytes``
        in torchair mode.
        """
        # Profile the memory usage of the model and get the maximum number of
        # cache blocks that can be allocated with the remaining free memory.
        NPUPlatform.clear_npu_memory()

        # Execute a forward pass with dummy inputs to profile the memory usage
        # of the model.
        _, total_npu_memory = NPUPlatform.mem_get_info()
        self.model_runner.profile_run()

        # Calculate the number of blocks that can be allocated with the
        # profiled peak memory.
        free_npu_memory, _ = NPUPlatform.mem_get_info()
        # NOTE(woosuk): Here we assume that the other processes using the same
        # GPU did not change their memory usage during the profiling.
        assert self.init_npu_memory > free_npu_memory, (
            "Error in memory profiling. "
            f"Initial free memory {self.init_npu_memory}, current free memory"
            f" {free_npu_memory}. This happens when the NPU memory was "
            "not properly cleaned up before initializing the vLLM instance.")

        # Get the peak memory allocation recorded by torch
        peak_memory = torch_npu.npu.memory_stats()["allocated_bytes.all.peak"]
        # TODO: don`t need impl this func after empty_cache in
        # Worker.determine_num_available_blocks() unified`
        NPUPlatform.empty_cache()
        torch_allocated_bytes = torch_npu.npu.memory_stats(
        )["allocated_bytes.all.current"]
        # Query free/total once so both numbers describe the same instant;
        # two separate calls could straddle an allocation.
        free_bytes, total_bytes = torch_npu.npu.mem_get_info()
        total_allocated_bytes = total_bytes - free_bytes
        # Device memory in use that torch's allocator does not account for.
        non_torch_allocations = total_allocated_bytes - torch_allocated_bytes
        if non_torch_allocations > 0:
            peak_memory += non_torch_allocations
        available_kv_cache_memory = max(
            int(total_npu_memory * self.cache_config.gpu_memory_utilization -
                peak_memory), 0)
        logger.info("Available memory: %s, total memory: %s",
                    available_kv_cache_memory, total_npu_memory)

        if get_ascend_config().torchair_graph_config.enabled:
            if check_torchair_cache_exist(
            ) and check_kv_cache_bytes_cache_exist():
                old_kv_cache_bytes = read_kv_cache_bytes_from_file(
                    torch.distributed.get_rank())
                # Reuse the cached size only if it still fits, so the cached
                # torchair graph stays valid.
                if 0 < old_kv_cache_bytes <= available_kv_cache_memory:
                    logger.info("Use cached torchair kv_cache_bytes: %s",
                                old_kv_cache_bytes)
                    self.model_runner.new_kv_cache_bytes = old_kv_cache_bytes
                    return old_kv_cache_bytes
                logger.info(
                    "Cached torchair kv_cache_bytes is too big, invalidate old torchair_cache"
                )
                delete_torchair_cache_file()
            bytes_floating_tolerance = 1024 * 1024 * envs_ascend.VLLM_ASCEND_KV_CACHE_MEGABYTES_FLOATING_TOLERANCE
            # Leave a safety margin, but never report a negative budget.
            available_kv_cache_memory = max(
                available_kv_cache_memory - bytes_floating_tolerance, 0)
            logger.info("Use new kv_cache_bytes: %s", available_kv_cache_memory)
            self.model_runner.new_kv_cache_bytes = available_kv_cache_memory

        return available_kv_cache_memory

    def execute_model(
        self,
        scheduler_output: "SchedulerOutput",
    ) -> Optional[ModelRunnerOutput]:
        """Run one model step, handling pipeline-parallel hand-off.

        Non-first pipeline ranks first receive intermediate tensors from the
        previous stage. Non-last ranks (unless using the external launcher)
        send their output to the next stage and return None. Otherwise the
        ModelRunnerOutput is returned, but only by the driver worker.
        """
        intermediate_tensors = None
        if not get_pp_group().is_first_rank:
            intermediate_tensors = IntermediateTensors(
                get_pp_group().recv_tensor_dict(
                    all_gather_group=get_tp_group()))

        output = self.model_runner.execute_model(scheduler_output,
                                                 intermediate_tensors)
        parallel_config = self.vllm_config.parallel_config
        if parallel_config.distributed_executor_backend != "external_launcher" \
                and not get_pp_group().is_last_rank:
            assert isinstance(output, IntermediateTensors)
            get_pp_group().send_tensor_dict(output.tensors,
                                            all_gather_group=get_tp_group())
            return None
        assert isinstance(output, ModelRunnerOutput)
        return output if self.is_driver_worker else None

    def load_model(self) -> None:
        """Load model weights, into the "weights" pool when sleep mode is on."""
        if self.vllm_config.model_config.enable_sleep_mode:
            allocator = CaMemAllocator.get_instance()
            assert allocator.get_current_usage() == 0, (
                "Sleep mode can only be "
                "used for one instance per process.")
        with self._memory_pool_context("weights"):
            self.model_runner.load_model()

    def compile_or_warm_up_model(self) -> None:
        """Warm up compile sizes, capture graphs, then reset the RNG seed."""
        warmup_sizes = self.vllm_config.compilation_config.compile_sizes.copy()
        if not self.model_config.enforce_eager:
            # Sizes that will be graph-captured are warmed up by capture_model.
            warmup_sizes = [
                x for x in warmup_sizes if x not in
                self.vllm_config.compilation_config.cudagraph_capture_sizes
            ]
        for size in sorted(warmup_sizes, reverse=True):
            logger.info("Compile and warming up model for size %d", size)
            self.model_runner._dummy_run(size)
        if not self.model_config.enforce_eager:
            self.model_runner.capture_model()
        # Reset the seed to ensure that the random state is not affected by
        # the model initialization and profiling.
        NPUPlatform.seed_everything(self.model_config.seed)

    def get_model(self) -> nn.Module:
        """Return the model held by the runner."""
        return self.model_runner.get_model()

    def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
        """Return the runner's KV-cache spec, keyed by layer name."""
        return self.model_runner.get_kv_cache_spec()

    def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None:
        """Allocate NPU KV cache with the specified kv_cache_config."""
        with self._memory_pool_context("kv_cache"):
            self.model_runner.initialize_kv_cache(kv_cache_config)

    def profile(self, is_start: bool = True):
        """Start (``is_start=True``) or stop the torch_npu profiler.

        Raises:
            RuntimeError: If profiling was not enabled at construction time.
        """
        if self.profiler is None:
            raise RuntimeError("Profiler is not enabled.")
        if is_start:
            self.profiler.start()
        else:
            self.profiler.stop()

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Forward to ``NPUModelRunner.add_lora``."""
        return self.model_runner.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        """Forward to ``NPUModelRunner.remove_lora``."""
        return self.model_runner.remove_lora(lora_id)

    def list_loras(self) -> set[int]:
        """Forward to ``NPUModelRunner.list_loras``."""
        return self.model_runner.list_loras()

    def pin_lora(self, lora_id: int) -> bool:
        """Forward to ``NPUModelRunner.pin_lora``."""
        return self.model_runner.pin_lora(lora_id)

    def execute_dummy_batch(self) -> None:
        """Run a dummy forward pass (e.g. to stay in lockstep with DP peers).

        With data parallelism the token count / prefill flag are agreed across
        DP ranks first; in torchair graph mode a decode-only batch is padded to
        a supported graph batch size.
        """
        runner = self.model_runner
        max_num_tokens = 1
        with_prefill = False
        if runner.dp_size > 1:
            max_num_tokens, with_prefill = runner._get_forward_metadata_across_dp(
                max_num_tokens, with_prefill)
        if runner.torchair_graph_enabled and not with_prefill:
            max_num_tokens = runner.select_torchair_padded_batch_size(
                max_num_tokens)
        runner._dummy_run(max_num_tokens,
                          is_compile=False,
                          with_prefill=with_prefill)

    def _init_worker_distributed_environment(self) -> None:
        """Initialize the distributed environment."""
        parallel_config = self.vllm_config.parallel_config
        init_distributed_environment(self.parallel_config.world_size,
                                     self.rank, self.distributed_init_method,
                                     self.local_rank, "hccl")
        ensure_model_parallel_initialized(
            self.parallel_config.tensor_parallel_size,
            self.parallel_config.pipeline_parallel_size)
        # Ascend-specific expert-parallel groups.
        init_ascend_model_parallel(
            parallel_config.expert_parallel_size,
            parallel_config.expert_tensor_parallel_size,
            parallel_config.world_size_across_dp,
        )
        ensure_kv_transfer_initialized(self.vllm_config)

    def _init_profiler(self):
        """Build a torch_npu profiler, or return None if profiling is off."""
        # Torch profiler. Enabled and configured through env vars:
        # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
        if envs.VLLM_TORCH_PROFILER_DIR:
            torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR
            logger.info("Profiling enabled. Traces will be saved to: %s",
                        torch_profiler_trace_dir)

            experimental_config = torch_npu.profiler._ExperimentalConfig(
                export_type=torch_npu.profiler.ExportType.Text,
                profiler_level=torch_npu.profiler.ProfilerLevel.Level1,
                msprof_tx=False,
                aic_metrics=torch_npu.profiler.AiCMetrics.AiCoreNone,
                l2_cache=False,
                op_attr=False,
                data_simplification=False,
                record_op_args=False,
                gc_detect_threshold=None,
            )

            return torch_npu.profiler.profile(
                activities=[
                    torch_npu.profiler.ProfilerActivity.CPU,
                    torch_npu.profiler.ProfilerActivity.NPU,
                ],
                with_stack=False,
                profile_memory=False,
                with_modules=False,
                experimental_config=experimental_config,
                on_trace_ready=torch_npu.profiler.tensorboard_trace_handler(
                    torch_profiler_trace_dir))
        else:
            return None