[Feature] support aclgraph for model runner v2 (#7110)

### What this PR does / why we need it?
This PR aims to support aclgraph for model runner v2, please see RFC
#5208. The PR contains these modifications:
- adapt to newest commit of vllm main branch.
- supply a unified interface of extra forward context for both model
runner v1 and model runner v2.
- implement graph mode for main model. 

### Does this PR introduce _any_ user-facing change?
no

### How was this patch tested?

- vLLM version: v0.16.0
- vLLM main:
4034c3d32e

---------

Signed-off-by: Ronald1995 <ronaldautomobile@163.com>
This commit is contained in:
Ronald
2026-03-13 09:11:46 +08:00
committed by GitHub
parent 1f71da80eb
commit c980e68d40
52 changed files with 840 additions and 309 deletions

View File

@@ -17,12 +17,16 @@
# This file is a part of the vllm-ascend project.
#
import functools
import numpy as np
import torch
import vllm
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.config.compilation import CUDAGraphMode
from vllm.sequence import IntermediateTensors
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.worker.gpu.attn_utils import build_slot_mappings_by_layer
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.worker.gpu.buffer_utils import async_copy_to_gpu
from vllm.v1.worker.gpu.input_batch import (
combine_sampled_and_draft_tokens,
@@ -32,23 +36,23 @@ from vllm.v1.worker.gpu.input_batch import (
)
from vllm.v1.worker.gpu.model_runner import GPUModelRunner
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.utils import set_weight_prefetch_method
from vllm_ascend.worker.v2.aclgraph_utils import AclGraphManager
from vllm_ascend.worker.v2.attn_utils import build_attn_metadata, build_attn_state
from vllm_ascend.worker.v2.attn_utils import build_attn_state
from vllm_ascend.worker.v2.input_batch import AscendInputBatch, AscendInputBuffers
from vllm_ascend.worker.v2.sample.sampler import AscendSampler
from vllm_ascend.worker.v2.spec_decode import init_speculator
from vllm_ascend.worker.v2.spec_decode.eagle import AscendEagleSpeculator
from vllm_ascend.worker.v2.states import AscendRequestState
from vllm_ascend.worker.v2.utils import torch_cuda_wrapper
logger = init_logger(__name__)
from vllm_ascend.worker.v2.utils import block_table_wrapper, model_states_wrapper, torch_cuda_wrapper
class NPUModelRunner(GPUModelRunner):
"""Model runner for Ascend NPUs."""
def __init__(self, vllm_config: VllmConfig, device: torch.device):
with torch_cuda_wrapper():
with torch_cuda_wrapper(), block_table_wrapper(), model_states_wrapper():
super().__init__(vllm_config, device)
# because we will override these attribute, delete these attribute to
@@ -62,8 +66,9 @@ class NPUModelRunner(GPUModelRunner):
# NPU specific initializations can be added below.
self.cudagraph_manager: AclGraphManager = AclGraphManager(
self.vllm_config,
self.uses_mrope,
self.use_aux_hidden_state_outputs,
self.device,
self,
)
# we define AscendEagleSpeculator in vllm_ascend.worker.v2.spec_decode.eagle
@@ -96,6 +101,7 @@ class NPUModelRunner(GPUModelRunner):
max_num_reqs=self.max_num_reqs,
vocab_size=self.vocab_size,
device=self.device,
req_states=self.req_states,
logprobs_mode=self.model_config.logprobs_mode,
num_speculative_tokens=self.num_speculative_steps + 1,
)
@@ -113,6 +119,59 @@ class NPUModelRunner(GPUModelRunner):
pin_memory=True,
)
# Ascend-specific configurations
self.ascend_config = get_ascend_config()
# set this just the same as model runner v1, or it will raise error.
set_weight_prefetch_method(self.ascend_config.weight_prefetch_config)
# we need to update full graph params in run_fullgraph,
# so create a stream to update full graph params.
if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
self.update_stream: torch.npu.Stream = torch.npu.Stream()
# we need to use return value of `get_cudagraph_and_dp_padding`
# to set forward_context in `run_fullgraph`.
# so we can inherit `execute_model` method.
self.cudagraph_and_dp_padding: tuple[int, torch.Tensor | None, int] | None = None
# we need to use input_batch to set forward_context in run_fullgraph.
# so we can inherit `execute_model` method.
self.input_batch: AscendInputBatch | None = None
@torch.inference_mode()
def execute_model(
self,
scheduler_output: SchedulerOutput,
intermediate_tensors: IntermediateTensors | None = None,
dummy_run: bool = False,
skip_attn_for_dummy_run: bool = False,
) -> ModelRunnerOutput | IntermediateTensors | None:
"""Override GPUModelRunner.execute_model for Ascend NPUs by there reasons:
1. when run fullgraph, we need to use ret value of `get_cudagraph_and_dp_padding`
to set forward_context in `run_fullgraph`.
"""
# use closure to store return value of get_cudagraph_and_dp_padding in model runner.
def wrapper(func):
@functools.wraps(func)
def inner(*args, **kwargs):
self.cudagraph_and_dp_padding = func(*args, **kwargs)
return self.cudagraph_and_dp_padding
return inner
if self.cudagraph_and_dp_padding is None:
vllm.v1.worker.gpu.model_runner.get_cudagraph_and_dp_padding = wrapper(
vllm.v1.worker.gpu.model_runner.get_cudagraph_and_dp_padding
)
return super().execute_model(
scheduler_output,
intermediate_tensors,
dummy_run,
skip_attn_for_dummy_run,
)
def prepare_inputs(
self,
scheduler_output: SchedulerOutput,
@@ -185,33 +244,40 @@ class NPUModelRunner(GPUModelRunner):
idx_mapping, total_num_logits, cu_num_logits, max_expand_len
)
# Block tables: num_kv_cache_groups x [num_reqs, max_num_blocks]
block_tables = self.block_tables.gather_block_tables(idx_mapping)
# Get query_start_loc.
query_start_loc_np = np.empty(self.max_num_reqs + 1, dtype=np.int32)
# NOTE: For FULL mode we change +1 to +2 to reserve extra space for padding.
# See _pad_query_start_loc_for_fia.
query_start_loc_np = np.empty(self.max_num_reqs + 2, dtype=np.int32)
query_start_loc_np[0] = 0
np.cumsum(num_scheduled_tokens, out=query_start_loc_np[1 : num_reqs + 1])
# Pad for full CUDA graph mode.
# Some attention backends like FA3 require query_start_loc to be non-decreasing.
query_start_loc_np[num_reqs + 1 :] = num_tokens
# This is only required for vllm-ascend.
query_start_loc_np, num_reqs_padded = self._pad_query_start_loc_for_fia(
num_tokens_padded=num_tokens_after_padding,
num_tokens=num_tokens,
num_reqs=num_reqs,
query_start_loc_np=query_start_loc_np,
max_query_len=max(scheduler_output.num_scheduled_tokens.values()),
)
async_copy_to_gpu(query_start_loc_np, out=self.input_buffers.query_start_loc)
query_start_loc_np = query_start_loc_np[: num_reqs + 1]
query_start_loc_cpu = torch.from_numpy(query_start_loc_np)
query_start_loc_np = query_start_loc_np[: num_reqs_padded + 1]
query_start_loc = self.input_buffers.query_start_loc[: num_reqs + 1]
max_query_len = num_scheduled_tokens.max().item()
# Get prefill tokens.
prepare_prefill_inputs(
self.input_buffers.input_ids,
self.req_states.next_prefill_tokens,
idx_mapping,
query_start_loc,
self.req_states.prefill_token_ids.gpu,
self.req_states.prefill_len.gpu,
self.req_states.num_computed_tokens.gpu,
)
# Get prefill tokens if any.
if self.req_states.any_prefills(idx_mapping_np):
prepare_prefill_inputs(
self.input_buffers.input_ids,
self.req_states.next_prefill_tokens,
idx_mapping,
query_start_loc,
self.req_states.all_token_ids.gpu,
self.req_states.prefill_len.gpu,
self.req_states.num_computed_tokens.gpu,
)
# Prepare positions and seq_lens.
prepare_pos_seq_lens(
@@ -223,14 +289,8 @@ class NPUModelRunner(GPUModelRunner):
)
seq_lens = self.input_buffers.seq_lens[:num_reqs]
# Prepare M-RoPE positions.
if self.uses_mrope:
self.mrope_states.prepare_mrope_positions(
idx_mapping,
query_start_loc,
self.req_states.prefill_len.gpu,
self.req_states.num_computed_tokens.gpu,
)
# Pad for full CUDA graph mode.
self.input_buffers.seq_lens_np[num_reqs_padded:] = 0
# Some input token ids are directly read from the last sampled tokens
# and draft tokens. Also, get the logits indices to sample tokens from.
@@ -246,43 +306,12 @@ class NPUModelRunner(GPUModelRunner):
total_num_logits,
)
# Compute slot mappings: [num_kv_cache_groups, num_tokens]
slot_mappings = self.block_tables.compute_slot_mappings(
idx_mapping, query_start_loc, self.input_buffers.positions[:num_tokens]
)
# Layer name -> slot mapping.
slot_mappings_by_layer = build_slot_mappings_by_layer(slot_mappings, self.kv_cache_config)
# Layer name -> attention metadata.
# TODO(Ronald1995): try to add a new method `build_attn_metadata` in
# vllm gpu_model_runner_v2, maybe we don't overwrite `prepare_inputs`
# method like this.
attn_metadata = build_attn_metadata(
attn_metadata_builders=self.attn_metadata_builders,
num_reqs=num_reqs,
num_tokens=num_tokens,
query_start_loc_gpu=query_start_loc,
query_start_loc_cpu=query_start_loc_cpu,
max_query_len=max_query_len,
seq_lens=self.input_buffers.seq_lens,
max_seq_len=self.max_model_len,
block_tables=block_tables,
slot_mappings=slot_mappings,
kv_cache_config=self.kv_cache_config,
# extra attributes for ascend npus.
seq_lens_np=self.input_buffers.seq_lens_np,
num_computed_tokens_cpu=self.req_states.num_computed_tokens_cpu[idx_mapping_cpu],
attn_state=attn_state,
)
input_ids = self.input_buffers.input_ids[:num_tokens_after_padding]
positions = self.input_buffers.positions[:num_tokens_after_padding]
mrope_positions = None
if self.uses_mrope:
mrope_positions = self.mrope_states.mrope_positions
mrope_positions = mrope_positions[:, :num_tokens_after_padding]
return AscendInputBatch(
self.input_batch = AscendInputBatch(
req_ids=req_ids,
num_reqs=num_reqs,
num_reqs=num_reqs_padded,
idx_mapping=idx_mapping,
idx_mapping_np=idx_mapping_np,
expanded_idx_mapping=expanded_idx_mapping,
@@ -294,18 +323,18 @@ class NPUModelRunner(GPUModelRunner):
query_start_loc=query_start_loc,
query_start_loc_np=query_start_loc_np,
seq_lens=seq_lens,
dcp_local_seq_lens=None, # TODO(Ronald1995): support cp.
input_ids=input_ids,
positions=positions,
mrope_positions=mrope_positions,
inputs_embeds=None,
attn_metadata=attn_metadata,
slot_mappings=slot_mappings_by_layer,
logits_indices=logits_indices,
cu_num_logits=cu_num_logits,
cu_num_logits_np=cu_num_logits_np,
has_structured_output_reqs=scheduler_output.has_structured_output_requests,
# extra attributes for ascend npus.
seq_lens_np=self.input_buffers.seq_lens_np,
attn_state=attn_state,
)
return self.input_batch
def postprocess(
self,
@@ -352,7 +381,7 @@ class NPUModelRunner(GPUModelRunner):
self.req_states.num_computed_tokens_cpu[req_index] = self.num_computed_tokens_cpu[req_index]
# update seq_lens_cpu
for i, req_id in enumerate(req_ids):
for i, req_id in enumerate(req_ids): # type: ignore
req_index = self.req_states.req_id_to_index[req_id]
num_computed_tokens = self.req_states.num_computed_tokens_cpu[req_index]
self.input_buffers.seq_lens_cpu[i] = num_computed_tokens + num_scheduled_tokens[req_id]
@@ -361,3 +390,44 @@ class NPUModelRunner(GPUModelRunner):
# TODO(Ronald1995): just define the method in case calling error in
# worker, implement it in the future.
pass
def _pad_query_start_loc_for_fia(
self,
num_tokens_padded: int,
num_tokens: int,
num_reqs: int,
query_start_loc_np: np.ndarray,
max_query_len: int,
) -> tuple[np.ndarray, int]:
"""
This function is only designed to satisfied the constraint that when the layout is TND,
the first dimension of `hidden_states` must equal the last element of `actual_seq_lengths_q`.
"""
assert self.cudagraph_and_dp_padding is not None
_num_tokens_after_padding, _num_tokens_across_dp, synced_cudagraph_mode = self.cudagraph_and_dp_padding
cudagraph_runtime_mode = CUDAGraphMode(synced_cudagraph_mode)
if cudagraph_runtime_mode != CUDAGraphMode.FULL:
return query_start_loc_np, num_reqs
uniform_decode_query_len = self.cudagraph_manager.uniform_decode_query_len
is_uniform_decode = self.cudagraph_manager.is_uniform_decode(
num_reqs=num_reqs,
num_tokens=num_tokens,
max_query_len=max_query_len,
)
if is_uniform_decode:
# Uniform-batch case: num_reqs must be no greater than num_reqs_padded
num_reqs_padded = num_tokens_padded // uniform_decode_query_len
last_loc = query_start_loc_np[num_reqs]
query_start_loc_np[num_reqs + 1 : num_reqs_padded + 1] = (
np.arange(1, num_reqs_padded + 1 - num_reqs) * uniform_decode_query_len + last_loc
)
else:
# Mixed-batch case: num_reqs must equal num_reqs_padded
num_reqs_padded = min(num_tokens_padded, self.max_num_reqs)
# Insert a dummy request instead of setting query_start_loc[num_reqs] = num_tokens_padded directly
query_start_loc_np[num_reqs_padded + 1] = num_tokens_padded
num_reqs_padded = num_reqs_padded + 1
return query_start_loc_np, num_reqs_padded