Files
xc-llm-ascend/vllm_ascend/worker/worker_v1.py

391 lines
17 KiB
Python
Raw Normal View History

#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
# Adapted from vllm-project/vllm/vllm/worker/gpu_worker.py
#
import copy
[Perf][V1] Fully overlap model execution (#2783) This PR is based on top of [#23569](https://github.com/vllm-project/vllm/pull/23569) and [#24219](https://github.com/vllm-project/vllm/pull/24219). ### What this PR does / why we need it? This PR allows the model runner to function asynchronously when using async scheduling. This allows full overlap of the cpu operations (including prepare_inputs) and the model forward pass. This diff is functional and does not support speculative decoding, PP, or guided decoding. Expected speedup is 5-10% over the current async scheduling. ### Does this PR introduce _any_ user-facing change? ### How was this patch tested? server ``` python -m vllm.entrypoints.openai.api_server --model=Qwen3-32B\ --trust-remote-code --enforce-eager \ --distributed-executor-backend=mp \ -tp=4 \ --port 8006 \ --max-model-len 32000 \ --block-size 128 \ --gpu-memory-utilization 0.99 ``` client ``` python $TEST_PY --backend vllm --trust-remote-code --model Qwen3-32B \ --dataset-name random --random-input-len 2048 --random-output-len 2048 \ --ignore-eos\ --num-prompts 48 --max-concurrency 48 --request-rate inf --temperature 0 \ --metric-percentiles 90 --base-url http://localhost:8006 --save-result \ --result-dir $PROFILER_DIR ``` benchmark test based on Qwen3-32B TPOT result: ||forward async| scheduler async |sync| |-|-|-|-| |avg|41.73|41.86|44.20| |improve0|0.3%|0|0| |improve1|5.58%|0|0| benchmark test based on Qwen2___5-VL-7B-Instruct TPOT result: ||forward async|sync| |-|-|-| |avg|23.22|29.16| |improve|20.3%|0| - vLLM version: main - vLLM main: https://github.com/vllm-project/vllm/commit/e93f4cc9e37484009f74e15d3111a1f335c532a5 Signed-off-by: jiangpeng36 <jiangpeng36@huawei.com> Signed-off-by: Ronald1995 <ronaldautomobile@163.com> Co-authored-by: jiangpeng36 <jiangpeng36@huawei.com> Co-authored-by: Ronald1995 <ronaldautomobile@163.com>
2025-09-11 16:35:36 +08:00
from typing import Optional, Union
import torch
import torch.nn as nn
import torch_npu
import vllm.envs as envs_vllm
support aclgraph (#426) <!-- Thanks for sending a pull request! BEFORE SUBMITTING, PLEASE READ https://docs.vllm.ai/en/latest/contributing/overview.html --> ### What this PR does / why we need it? <!-- - Please clarify what changes you are proposing. The purpose of this section is to outline the changes and how this PR fixes the issue. If possible, please consider writing useful notes for better and faster reviews in your PR. - Please clarify why the changes are needed. For instance, the use case and bug description. - Fixes # --> This PR supports the access of vllm-acend to the piecewise_graph feature provided by the v1 engine. 1. register unifiled_ascend_attention_with_output for piecewise_graph to split graph. 2. support NPUGraph to accelerate kernel launch. ### Does this PR introduce _any_ user-facing change? <!-- Note that it means *any* user-facing change including all aspects such as API, interface or other behavior changes. Documentation-only updates are not considered user-facing changes. --> support npugraph to default, Users can disenable the npugraph feature by configuring enforce_eager. This has corresponding requirements for the versions of torch_npu and CANN, and they need to support graph capture. ### How was this patch tested? <!-- CI passed with new added/existing test. If it was tested in a way different from regular unit tests, please clarify how you tested step by step, ideally copy and paste-able, so that other reviewers can test and check, and descendants can verify in the future. If tests were not added, please describe why they were not added and/or why it was difficult to add. --> it turn to default --------- Signed-off-by: Bug Hunter Yan <yanpq@zju.edu.cn> Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com> Co-authored-by: Yizhou Liu <liu_yizhou@outlook.com>
2025-04-23 20:56:24 +08:00
from torch_npu.op_plugin.atb._atb_ops import _register_atb_extensions
from vllm.config import VllmConfig
from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment)
from vllm.distributed.kv_transfer import (ensure_kv_transfer_initialized,
has_kv_transfer_group)
from vllm.distributed.parallel_state import get_pp_group, get_tp_group
from vllm.logger import logger
from vllm.lora.request import LoRARequest
from vllm.sequence import IntermediateTensors
from vllm.tasks import SupportedTask
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, GiB_bytes
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
[Perf][V1] Fully overlap model execution (#2783) This PR is based on top of [#23569](https://github.com/vllm-project/vllm/pull/23569) and [#24219](https://github.com/vllm-project/vllm/pull/24219). ### What this PR does / why we need it? This PR allows the model runner to function asynchronously when using async scheduling. This allows full overlap of the cpu operations (including prepare_inputs) and the model forward pass. This diff is functional and does not support speculative decoding, PP, or guided decoding. Expected speedup is 5-10% over the current async scheduling. ### Does this PR introduce _any_ user-facing change? ### How was this patch tested? server ``` python -m vllm.entrypoints.openai.api_server --model=Qwen3-32B\ --trust-remote-code --enforce-eager \ --distributed-executor-backend=mp \ -tp=4 \ --port 8006 \ --max-model-len 32000 \ --block-size 128 \ --gpu-memory-utilization 0.99 ``` client ``` python $TEST_PY --backend vllm --trust-remote-code --model Qwen3-32B \ --dataset-name random --random-input-len 2048 --random-output-len 2048 \ --ignore-eos\ --num-prompts 48 --max-concurrency 48 --request-rate inf --temperature 0 \ --metric-percentiles 90 --base-url http://localhost:8006 --save-result \ --result-dir $PROFILER_DIR ``` benchmark test based on Qwen3-32B TPOT result: ||forward async| scheduler async |sync| |-|-|-|-| |avg|41.73|41.86|44.20| |improve0|0.3%|0|0| |improve1|5.58%|0|0| benchmark test based on Qwen2___5-VL-7B-Instruct TPOT result: ||forward async|sync| |-|-|-| |avg|23.22|29.16| |improve|20.3%|0| - vLLM version: main - vLLM main: https://github.com/vllm-project/vllm/commit/e93f4cc9e37484009f74e15d3111a1f335c532a5 Signed-off-by: jiangpeng36 <jiangpeng36@huawei.com> Signed-off-by: Ronald1995 <ronaldautomobile@163.com> Co-authored-by: jiangpeng36 <jiangpeng36@huawei.com> Co-authored-by: Ronald1995 <ronaldautomobile@163.com>
2025-09-11 16:35:36 +08:00
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
DraftTokenIds, ModelRunnerOutput)
from vllm.v1.worker.worker_base import WorkerBase
from vllm_ascend.ascend_config import init_ascend_config
from vllm_ascend.device_allocator.camem import CaMemAllocator
from vllm_ascend.distributed.parallel_state import init_ascend_model_parallel
from vllm_ascend.platform import NPUPlatform
from vllm_ascend.utils import (init_ascend_soc_version,
register_ascend_customop, sleep_mode_enabled,
try_register_lib)
from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
torch._dynamo.trace_rules.clear_lru_cache() # noqa: E402
from torch._dynamo.variables import TorchInGraphFunctionVariable # noqa: E402
torch_non_c_binding_in_graph_functions_npu = dict.fromkeys(
["torch.npu.current_stream"],
TorchInGraphFunctionVariable,
) # noqa: E402
torch_non_c_binding_in_graph_functions_npu[
"torch.npu.stream"] = TorchInGraphFunctionVariable # noqa: E402
torch._dynamo.trace_rules.torch_name_rule_map.append(
torch_non_c_binding_in_graph_functions_npu) # noqa: E402
class NPUWorker(WorkerBase):
    """vLLM v1 worker bound to one Ascend NPU.

    The worker selects device ``npu:{local_rank}``, initializes the HCCL
    distributed environment, owns an ``NPUModelRunner`` and forwards the
    executor's calls (model loading, memory profiling, KV-cache allocation,
    execution, LoRA management, profiling, sleep/wake-up) to it.
    """

    def __init__(
            self,
            vllm_config: VllmConfig,
            local_rank: int,
            rank: int,
            distributed_init_method: str,
            is_driver_worker: bool = False,
            # Additional parameters for compatibility with vllm
            **kwargs):
        """Initialize the worker for Ascend.

        Args:
            vllm_config: Full vLLM configuration.
            local_rank: Device index on this host; used as ``npu:{local_rank}``.
            rank: Global rank passed to the distributed initialization.
            distributed_init_method: Init method passed to the distributed
                initialization.
            is_driver_worker: Forwarded unchanged to ``WorkerBase``.
            **kwargs: Accepted for signature compatibility with vLLM; unused.
        """
        # register patch for vllm
        from vllm_ascend.utils import adapt_patch
        adapt_patch()
        # Register ops when worker init.
        from vllm_ascend import ops
        ops.register_dummy_fusion_op()
        _register_atb_extensions()
        register_ascend_customop(vllm_config)
        # init ascend config and soc version
        init_ascend_config(vllm_config)
        init_ascend_soc_version()

        super().__init__(vllm_config=vllm_config,
                         local_rank=local_rank,
                         rank=rank,
                         distributed_init_method=distributed_init_method,
                         is_driver_worker=is_driver_worker)

        # Try to import mindie_turbo to accelerate vLLM inference.
        try_register_lib(
            "mindie_turbo",
            "MindIE Turbo is installed. vLLM inference will be accelerated with MindIE Turbo."
        )

        # "auto" means the KV cache uses the model's own dtype.
        if self.cache_config.cache_dtype == "auto":
            self.cache_dtype = self.model_config.dtype
        else:
            self.cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
                self.cache_config.cache_dtype]

        if self.model_config.trust_remote_code:
            # note: lazy import to avoid importing torch before initializing
            from vllm.utils import init_cached_hf_modules
            init_cached_hf_modules()

        self.profiler = self._init_profiler()
        if sleep_mode_enabled():
            # Buffers saved before sleep
            self._sleep_saved_buffers: dict[str, torch.Tensor] = {}

    def sleep(self, level: int = 1) -> None:
        """Release NPU memory through the CaMem allocator.

        Level 1 offloads the allocations tagged ``"weights"``; any other
        level passes no offload tags. For level 2 the model's named buffers
        are first copied to CPU so ``wake_up`` can restore them.

        Raises:
            ValueError: If sleep mode is not enabled.
        """
        if not sleep_mode_enabled():
            raise ValueError(
                "Sleep mode is not enabled. Please compile vllm-ascend with COMPILE_CUSTOM_KERNELS=1."
            )
        free_bytes_before_sleep = NPUPlatform.mem_get_info()[0]
        # Save the buffers before level 2 sleep
        if level == 2:
            model = self.model_runner.model
            self._sleep_saved_buffers = {
                name: buffer.cpu().clone()
                for name, buffer in model.named_buffers()
            }
        allocator = CaMemAllocator.get_instance()
        allocator.sleep(offload_tags=("weights", ) if level == 1 else tuple())
        free_bytes_after_sleep, total = NPUPlatform.mem_get_info()
        freed_bytes = free_bytes_after_sleep - free_bytes_before_sleep
        used_bytes = total - free_bytes_after_sleep
        assert freed_bytes >= 0, "Memory usage increased after sleeping."
        logger.info(
            "Sleep mode freed %.2f GiB memory, "
            "%.2f GiB memory is still in use.", freed_bytes / GiB_bytes,
            used_bytes / GiB_bytes)

    def wake_up(self, tags: Optional[list[str]] = None) -> None:
        """Wake the allocator up and restore buffers saved by a level 2 sleep.

        Args:
            tags: Forwarded unchanged to ``CaMemAllocator.wake_up``.

        Raises:
            ValueError: If sleep mode is not enabled.
        """
        if not sleep_mode_enabled():
            raise ValueError(
                "Sleep mode is not enabled. Please compile vllm-ascend with COMPILE_CUSTOM_KERNELS=1."
            )
        allocator = CaMemAllocator.get_instance()
        allocator.wake_up(tags=tags)

        # Restore the buffers after level 2 sleep
        if self._sleep_saved_buffers:
            model = self.model_runner.model
            for name, buffer in model.named_buffers():
                if name in self._sleep_saved_buffers:
                    buffer.data.copy_(self._sleep_saved_buffers[name].data)
            self._sleep_saved_buffers = {}

    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        """Record the chosen block counts on the cache config."""
        self.cache_config.num_gpu_blocks = num_gpu_blocks
        self.cache_config.num_cpu_blocks = num_cpu_blocks

    def _init_device(self) -> torch.device:
        """Select the NPU, snapshot free memory, and set up distributed state.

        Returns:
            The ``torch.device`` for ``npu:{local_rank}``.
        """
        device = torch.device(f"npu:{self.local_rank}")
        NPUPlatform.set_device(device)
        NPUPlatform.empty_cache()
        # Baseline free memory; checked again in determine_available_memory.
        self.init_npu_memory = NPUPlatform.mem_get_info()[0]
        # Initialize the distributed environment.
        self._init_worker_distributed_environment()
        # Set random seed.
        NPUPlatform.seed_everything(self.model_config.seed)
        return device

    def init_device(self) -> None:
        """Initialize the device and create the model runner on it."""
        device = self._init_device()
        # Init ModelRunner here, so that we have access to self.device.
        self.model_runner = NPUModelRunner(self.vllm_config, device)

    def determine_available_memory(self) -> int:
        """Profile a dummy forward pass and return the bytes left for KV cache.

        The result is ``total * gpu_memory_utilization - peak`` clamped at 0,
        where ``peak`` is torch's peak allocation plus any memory currently
        in use on the device that torch's allocator does not account for.
        """
        # Profile the memory usage of the model and get the maximum number of
        # cache blocks that can be allocated with the remaining free memory.
        NPUPlatform.clear_npu_memory()

        # Execute a forward pass with dummy inputs to profile the memory usage
        # of the model.
        _, total_npu_memory = NPUPlatform.mem_get_info()
        self.model_runner.profile_run()

        # Calculate the number of blocks that can be allocated with the
        # profiled peak memory.
        free_npu_memory, _ = NPUPlatform.mem_get_info()
        # NOTE(woosuk): Here we assume that the other processes using the same
        # GPU did not change their memory usage during the profiling.
        assert self.init_npu_memory > free_npu_memory, (
            "Error in memory profiling. "
            f"Initial free memory {self.init_npu_memory}, current free memory"
            f" {free_npu_memory}. This happens when the NPU memory was "
            "not properly cleaned up before initializing the vLLM instance.")

        # Get the peak memory allocation recorded by torch
        peak_memory = torch_npu.npu.memory_stats()["allocated_bytes.all.peak"]
        # TODO: this implementation is no longer needed once the empty_cache
        # handling in Worker.determine_num_available_blocks() is unified.
        NPUPlatform.empty_cache()
        torch_allocated_bytes = torch_npu.npu.memory_stats(
        )["allocated_bytes.all.current"]
        # Read free/total once so both values come from the same snapshot.
        free_bytes, total_bytes = torch_npu.npu.mem_get_info()
        total_allocated_bytes = total_bytes - free_bytes
        # Memory in use on the device that torch's allocator does not report.
        non_torch_allocations = total_allocated_bytes - torch_allocated_bytes
        if non_torch_allocations > 0:
            peak_memory += non_torch_allocations
        available_kv_cache_memory = int(
            total_npu_memory * self.cache_config.gpu_memory_utilization -
            peak_memory)
        available_kv_cache_memory = int(max(available_kv_cache_memory, 0))
        logger.info("Available memory: %s, total memory: %s",
                    available_kv_cache_memory, total_npu_memory)
        return available_kv_cache_memory

    def execute_model(
        self,
        scheduler_output: "SchedulerOutput",
    ) -> Optional[Union[ModelRunnerOutput, AsyncModelRunnerOutput]]:
        """Run one scheduled step, relaying tensors between pipeline stages.

        Non-first pipeline ranks receive intermediate tensors from the
        previous stage. Non-last ranks (unless the executor backend is
        ``"external_launcher"``) send their tensors onward and return either
        ``None`` or an empty runner output carrying the KV-connector status.
        Otherwise the model runner's output is returned as is.
        """
        intermediate_tensors = None
        if not get_pp_group().is_first_rank:
            intermediate_tensors = IntermediateTensors(
                get_pp_group().recv_tensor_dict(
                    all_gather_group=get_tp_group()))

        output = self.model_runner.execute_model(scheduler_output,
                                                 intermediate_tensors)
        parallel_config = self.vllm_config.parallel_config
        if parallel_config.distributed_executor_backend != "external_launcher" \
            and not get_pp_group().is_last_rank:
            assert isinstance(output, IntermediateTensors)
            get_pp_group().send_tensor_dict(output.tensors,
                                            all_gather_group=get_tp_group())
            if not has_kv_transfer_group():
                return None

            kv_connector_output = output.kv_connector_output
            # NOTE(review): guard added on the assumption that the runner may
            # attach no connector output for a step; without it the attribute
            # reads below raise AttributeError. Confirm against the runner.
            if kv_connector_output is None:
                return None
            finished_sending = kv_connector_output.finished_sending
            finished_recving = kv_connector_output.finished_recving

            if not finished_sending and not finished_recving:
                return EMPTY_MODEL_RUNNER_OUTPUT

            # Shallow-copy so the shared EMPTY_MODEL_RUNNER_OUTPUT object is
            # never mutated.
            new_output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
            new_output.kv_connector_output = kv_connector_output
            return new_output

        assert isinstance(output, (ModelRunnerOutput, AsyncModelRunnerOutput))
        return output

    def _memory_pool_context(self, tag: str):
        """Return the context manager under which ``tag`` memory is allocated.

        With sleep mode enabled in the model config this is the CaMem
        allocator's memory pool for ``tag``; otherwise it is a no-op context.
        """
        if self.vllm_config.model_config.enable_sleep_mode:
            return CaMemAllocator.get_instance().use_memory_pool(tag=tag)
        from contextlib import nullcontext
        return nullcontext()

    def load_model(self) -> None:
        """Load the model, inside the ``"weights"`` pool when sleep mode is on."""
        if self.vllm_config.model_config.enable_sleep_mode:
            allocator = CaMemAllocator.get_instance()
            assert allocator.get_current_usage() == 0, (
                "Sleep mode can only be "
                "used for one instance per process.")
        with self._memory_pool_context("weights"):
            self.model_runner.load_model()

    def compile_or_warm_up_model(self) -> None:
        """Warm up compile sizes, capture graphs if enabled, then reseed."""
        # Note: need to adapt for graph mode.
        self.model_runner.eplb_warmup()
        warmup_sizes = (self.vllm_config.compilation_config.compile_sizes
                        or []).copy()
        if not self.model_config.enforce_eager:
            # Skip the sizes listed in cudagraph_capture_sizes; graph capture
            # is triggered separately below.
            warmup_sizes = [
                x for x in warmup_sizes if x not in
                self.vllm_config.compilation_config.cudagraph_capture_sizes
            ]
        for size in sorted(warmup_sizes, reverse=True):
            logger.info("Compile and warming up model for size %d", size)
            self.model_runner._dummy_run(size)
        if not self.model_config.enforce_eager:
            self.model_runner.capture_model()
        # Call ATB matmul to warm up; otherwise, the first operation (ReshapeAndCache)
        # may cause performance degradation at runtime.
        self._warm_up_atb()
        # Reset the seed to ensure that the random state is not affected by
        # the model initialization and profiling.
        NPUPlatform.seed_everything(self.model_config.seed)

    def _warm_up_atb(self) -> None:
        """Issue one small ``_npu_matmul_add_fp32`` call as a warm-up."""
        x = torch.rand((2, 4), dtype=torch.float16).npu()
        weight = torch.rand((2, 4), dtype=torch.float16).npu()
        c = torch.rand((4, 4), dtype=torch.float32).npu()
        torch_npu._npu_matmul_add_fp32(x, weight, c)

    def get_model(self) -> nn.Module:
        """Return the model held by the model runner."""
        return self.model_runner.get_model()

    def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
        """Return the KV-cache spec reported by the model runner."""
        return self.model_runner.get_kv_cache_spec()

    def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None:
        """Allocate NPU KV cache with the specified kv_cache_config."""
        with self._memory_pool_context("kv_cache"):
            self.model_runner.initialize_kv_cache(kv_cache_config)

    def profile(self, is_start: bool = True) -> None:
        """Start or stop the profiler.

        Raises:
            RuntimeError: If no profiler was configured.
        """
        if self.profiler is None:
            raise RuntimeError("Profiler is not enabled.")
        if is_start:
            self.profiler.start()
        else:
            self.profiler.stop()

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Forward to the model runner's ``add_lora``."""
        return self.model_runner.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        """Forward to the model runner's ``remove_lora``."""
        return self.model_runner.remove_lora(lora_id)

    def list_loras(self) -> set[int]:
        """Forward to the model runner's ``list_loras``."""
        return self.model_runner.list_loras()

    def pin_lora(self, lora_id: int) -> bool:
        """Forward to the model runner's ``pin_lora``."""
        return self.model_runner.pin_lora(lora_id)

    def execute_dummy_batch(self) -> None:
        """Run a dummy forward pass of size 1."""
        self.model_runner._dummy_run(1)

    def _init_worker_distributed_environment(self) -> None:
        """Initialize the distributed environment."""
        init_distributed_environment(self.parallel_config.world_size,
                                     self.rank, self.distributed_init_method,
                                     self.local_rank, "hccl")
        ensure_model_parallel_initialized(
            self.parallel_config.tensor_parallel_size,
            self.parallel_config.pipeline_parallel_size)
        init_ascend_model_parallel(self.parallel_config)
        ensure_kv_transfer_initialized(self.vllm_config)

    def _init_profiler(self):
        """Build the torch_npu profiler, or return ``None`` when disabled.

        Profiling is enabled by setting ``VLLM_TORCH_PROFILER_DIR``; traces
        are written there through the tensorboard trace handler.
        """
        # Torch profiler. Enabled and configured through env vars:
        # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
        if not envs_vllm.VLLM_TORCH_PROFILER_DIR:
            return None

        torch_profiler_trace_dir = envs_vllm.VLLM_TORCH_PROFILER_DIR
        logger.info("Profiling enabled. Traces will be saved to: %s",
                    torch_profiler_trace_dir)

        experimental_config = torch_npu.profiler._ExperimentalConfig(
            export_type=torch_npu.profiler.ExportType.Text,
            profiler_level=torch_npu.profiler.ProfilerLevel.Level1,
            msprof_tx=False,
            aic_metrics=torch_npu.profiler.AiCMetrics.AiCoreNone,
            l2_cache=False,
            op_attr=False,
            data_simplification=False,
            record_op_args=False,
            gc_detect_threshold=None,
        )

        profile_memory = envs_vllm.VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY
        return torch_npu.profiler.profile(
            activities=[
                torch_npu.profiler.ProfilerActivity.CPU,
                torch_npu.profiler.ProfilerActivity.NPU,
            ],
            with_stack=envs_vllm.VLLM_TORCH_PROFILER_WITH_STACK,
            profile_memory=profile_memory,
            with_modules=False,
            experimental_config=experimental_config,
            on_trace_ready=torch_npu.profiler.tensorboard_trace_handler(
                torch_profiler_trace_dir))

    def get_supported_pooling_tasks(self):
        """Forward to the model runner's ``get_supported_pooling_tasks``."""
        return self.model_runner.get_supported_pooling_tasks()

    def get_supported_tasks(self) -> "tuple[SupportedTask, ...]":
        """Forward to the model runner's ``get_supported_tasks``."""
        return self.model_runner.get_supported_tasks()

    def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
        """Forward to the model runner's ``take_draft_token_ids``."""
        return self.model_runner.take_draft_token_ids()