# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/v1/worker/gpu/aclgraph_utils.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#

from typing import Any

import torch
import torch.nn as nn
from vllm.config import VllmConfig
from vllm.config.compilation import CUDAGraphMode
from vllm.forward_context import get_forward_context, set_forward_context
from vllm.logger import logger
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.worker.gpu.block_table import BlockTables
from vllm.v1.worker.gpu.cudagraph_utils import BatchExecutionDescriptor, ModelCudaGraphManager
from vllm.v1.worker.gpu.input_batch import InputBuffers
from vllm.v1.worker.gpu.model_states.interface import ModelState
from vllm.v1.worker.utils import AttentionGroup

from vllm_ascend.ascend_forward_context import _EXTRA_CTX
from vllm_ascend.compilation.acl_graph import set_graph_params, update_full_graph_params


class ModelAclGraphManager(ModelCudaGraphManager):
    """ACL graph manager for Ascend NPUs, built on vLLM's ModelCudaGraphManager."""

    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
        cudagraph_mode: CUDAGraphMode,
        decode_query_len: int,
        model_runner: Any,
    ):
        super().__init__(
            vllm_config,
            device,
            cudagraph_mode,
            decode_query_len,
        )
        # Keep a reference to the model runner so its attributes are
        # accessible when `run_fullgraph` is called on this manager; this
        # avoids duplicating the `execute_model` method in `NPUModelRunner`.
        self.model_runner = model_runner
        # `capture_sizes` is sorted in ascending order.
        self.capture_sizes = sorted(self.compilation_config.cudagraph_capture_sizes)
        # vllm-ascend needs to update the graph params of the attention
        # backend, so the graph params must be set before the full graph
        # is captured.
        if super().needs_capture():
            set_graph_params(self.capture_sizes)

    def run_fullgraph(self, desc: BatchExecutionDescriptor) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:
        """Override `run_fullgraph` to also update the full-graph params."""
        num_tokens = desc.num_tokens
        logger.info_once(f"run_fullgraph with num_tokens={num_tokens}")
        ret = super().run_fullgraph(desc)
        positions = self.model_runner.input_buffers.positions[:num_tokens]
        # Refer to vllm.v1.worker.gpu.dp_utils.sync_cudagraph_and_dp_padding
        # for how num_tokens_across_dp is calculated.
        num_tokens_across_dp = torch.full([self.model_runner.dp_size], num_tokens, device=self.device)
        with set_forward_context(
            self.model_runner.input_batch.attn_metadata,
            self.vllm_config,
            num_tokens=num_tokens,
            cudagraph_runtime_mode=desc.cg_mode,
            num_tokens_across_dp=num_tokens_across_dp,
            batch_descriptor=None,  # Full-graph mode doesn't need a batch descriptor.
            slot_mapping=self.model_runner.input_batch.slot_mappings,
        ):
            forward_context = get_forward_context()
            update_full_graph_params(
                # FIXME(Ronald1995): support hybrid attn backend
                list(self.model_runner.attn_backends.values())[0],
                self.model_runner.update_stream,
                forward_context,
                num_tokens,
                self.vllm_config,
                self.model_runner.speculative_config,
                positions.shape[0],
            )
        return ret

    def capture(
        self,
        model: nn.Module,
        model_state: ModelState,
        input_buffers: InputBuffers,
        block_tables: BlockTables,
        attn_groups: list[list[AttentionGroup]],
        kv_cache_config: KVCacheConfig,
        has_lora: bool = False,
        use_aux_hidden_state_outputs: bool = False,
        progress_bar_desc: str = "Capturing CUDA graphs",
    ) -> None:
        """Capture CUDA graphs for the model's forward pass."""
        model = ModelWithContext(model)
        return super().capture(
            model,
            model_state,
            input_buffers,
            block_tables,
            attn_groups,
            kv_cache_config,
            has_lora,
            use_aux_hidden_state_outputs,
            progress_bar_desc,
        )


class ModelWithContext(nn.Module):
    """Wrapper model that injects the forward context, so vLLM's
    CudaGraphManager._capture_full_graph can be inherited unchanged.
    """

    def __init__(self, original_model):
        super().__init__()
        self.original_model = original_model

    def forward(self, *args, **kwargs):
        # In the warmup phase, `capturing` is False by default; while a
        # graph is being captured, set `capturing=True` on the forward
        # context so downstream code can tell the two phases apart.
        if torch.npu.is_current_stream_capturing():
            _EXTRA_CTX.capturing = True
        return self.original_model(*args, **kwargs)
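

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of this module): a minimal
# example of how an NPU model runner might wire up the manager, assuming
# hypothetical attribute names on `runner` (`runner.device`,
# `runner.decode_query_len`, `runner.model_state`, `runner.attn_groups`,
# `runner.kv_cache_config`); the actual integration lives in
# vllm-ascend's NPUModelRunner. The constructor and `capture` signatures
# below match this file.
#
#     graph_manager = ModelAclGraphManager(
#         vllm_config=runner.vllm_config,
#         device=runner.device,
#         cudagraph_mode=CUDAGraphMode.FULL,
#         decode_query_len=runner.decode_query_len,
#         model_runner=runner,
#     )
#     # Wraps the model in ModelWithContext internally, then captures one
#     # graph per size in `capture_sizes`.
#     graph_manager.capture(
#         model=runner.model,
#         model_state=runner.model_state,
#         input_buffers=runner.input_buffers,
#         block_tables=runner.block_tables,
#         attn_groups=runner.attn_groups,
#         kv_cache_config=runner.kv_cache_config,
#     )
# ---------------------------------------------------------------------------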