xc-llm-ascend/vllm_ascend/worker/v2/aclgraph_utils.py
Ronald d96440924a adapt to main2main for model runner v2 (#7578)
### What this PR does / why we need it?
This PR adapts model runner v2 to the newest commit of the vLLM main branch. Please refer to
https://github.com/vllm-project/vllm-ascend/issues/5208
### Does this PR introduce _any_ user-facing change?
no

### How was this patch tested?

- vLLM version: v0.18.0
- vLLM main: ed359c497a

---------

Signed-off-by: Ronald1995 <ronaldautomobile@163.com>
2026-03-25 09:08:44 +08:00


# Adapt from https://github.com/vllm-project/vllm/blob/main/vllm/v1/worker/gpu/aclgraph_utils.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
from typing import Any

import torch
import torch.nn as nn
from vllm.config import VllmConfig
from vllm.config.compilation import CUDAGraphMode
from vllm.forward_context import get_forward_context, set_forward_context
from vllm.logger import logger
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.worker.gpu.block_table import BlockTables
from vllm.v1.worker.gpu.cudagraph_utils import (BatchExecutionDescriptor,
                                                ModelCudaGraphManager)
from vllm.v1.worker.gpu.input_batch import InputBuffers
from vllm.v1.worker.gpu.model_states.interface import ModelState
from vllm.v1.worker.utils import AttentionGroup

from vllm_ascend.ascend_forward_context import _EXTRA_CTX
from vllm_ascend.compilation.acl_graph import (set_graph_params,
                                               update_full_graph_params)


class ModelAclGraphManager(ModelCudaGraphManager):
    """ACL model CUDA graph manager for Ascend NPUs."""

    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
        cudagraph_mode: CUDAGraphMode,
        decode_query_len: int,
        model_runner: Any,
    ):
        super().__init__(
            vllm_config,
            device,
            cudagraph_mode,
            decode_query_len,
        )
        # Keep a reference to the model runner so its attributes are
        # accessible when `run_fullgraph` is called in CudaGraphManager;
        # this avoids copying the `execute_model` method into the
        # `NPUModelRunner` class.
        self.model_runner = model_runner
        # capture_sizes is sorted in ascending order.
        self.capture_sizes = sorted(
            self.compilation_config.cudagraph_capture_sizes)
        # vllm-ascend needs to update the graph params of the attention
        # backend, so the graph params must be set before capturing the
        # full graph.
        if super().needs_capture():
            set_graph_params(self.capture_sizes)

    def run_fullgraph(
        self, desc: BatchExecutionDescriptor
    ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:
        """Override run_fullgraph to update the full graph params after replay."""
        num_tokens = desc.num_tokens
        logger.info_once(f"run_fullgraph with num_tokens={num_tokens}")
        ret = super().run_fullgraph(desc)
        positions = self.model_runner.input_buffers.positions[:num_tokens]
        # Refer to vllm.v1.worker.gpu.dp_utils.sync_cudagraph_and_dp_padding
        # for how num_tokens_across_dp is calculated.
        num_tokens_across_dp = torch.full([self.model_runner.dp_size],
                                          num_tokens,
                                          device=self.device)
        with set_forward_context(
                self.model_runner.input_batch.attn_metadata,
                self.vllm_config,
                num_tokens=num_tokens,
                cudagraph_runtime_mode=desc.cg_mode,
                num_tokens_across_dp=num_tokens_across_dp,
                batch_descriptor=None,  # Full graph mode doesn't need a batch_descriptor.
                slot_mapping=self.model_runner.input_batch.slot_mappings,
        ):
            forward_context = get_forward_context()
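            # The call below is understood to refresh the attention-backend
            # params recorded in the captured full graph, using the dedicated
            # update stream and the current forward context, so graph replays
            # work with current attention metadata rather than the values
            # present at capture time (see vllm_ascend.compilation.acl_graph).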
            update_full_graph_params(
                # FIXME(Ronald1995): support hybrid attn backend
                list(self.model_runner.attn_backends.values())[0],
                self.model_runner.update_stream,
                forward_context,
                num_tokens,
                self.vllm_config,
                self.model_runner.speculative_config,
                positions.shape[0],
            )
        return ret

    def capture(
        self,
        model: nn.Module,
        model_state: ModelState,
        input_buffers: InputBuffers,
        block_tables: BlockTables,
        attn_groups: list[list[AttentionGroup]],
        kv_cache_config: KVCacheConfig,
        has_lora: bool = False,
        use_aux_hidden_state_outputs: bool = False,
        progress_bar_desc: str = "Capturing CUDA graphs",
    ) -> None:
        """Capture CUDA graphs for model forward pass."""
        model = ModelWithContext(model)
        return super().capture(
            model,
            model_state,
            input_buffers,
            block_tables,
            attn_groups,
            kv_cache_config,
            has_lora,
            use_aux_hidden_state_outputs,
            progress_bar_desc,
        )


class ModelWithContext(nn.Module):
    """Wrapper model that injects the forward context,
    so vllm's CudaGraphManager._capture_full_graph can be inherited as-is.
    """

    def __init__(self, original_model):
        super().__init__()
        self.original_model = original_model

    def forward(self, *args, **kwargs):
        # In the warmup phase, capturing is False by default.
        # When the stream is capturing, capturing=True must be set in the
        # forward context.
        if torch.npu.is_current_stream_capturing():
            _EXTRA_CTX.capturing = True
        return self.original_model(*args, **kwargs)
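

# Illustrative usage sketch (an assumption for documentation purposes, not
# part of the upstream module): the NPU model runner is expected to build the
# manager with itself as `model_runner`, capture once, then replay per batch.
# `runner`, `model_state`, and `desc` below are hypothetical placeholders for
# objects the real NPUModelRunner already owns.
#
#     manager = ModelAclGraphManager(
#         vllm_config=runner.vllm_config,
#         device=runner.device,
#         cudagraph_mode=CUDAGraphMode.FULL,
#         decode_query_len=1,
#         model_runner=runner,
#     )
#     manager.capture(runner.model, model_state, runner.input_buffers,
#                     runner.block_tables, runner.attn_groups,
#                     runner.kv_cache_config)
#     output = manager.run_fullgraph(desc)  # desc: BatchExecutionDescriptor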