### What this PR does / why we need it?
This PR adds aclgraph support for model runner v2 (see RFC #5208).
The PR contains the following modifications:
- Adapt to the newest commit of the vLLM main branch.
- Provide a unified interface for the extra forward context that is shared
  by both model runner v1 and model runner v2.
- Implement graph mode for the main model.
### Does this PR introduce _any_ user-facing change?
no
### How was this patch tested?
- vLLM version: v0.16.0
- vLLM main: 4034c3d32e
---------
Signed-off-by: Ronald1995 <ronaldautomobile@163.com>
293 lines
12 KiB
Python
293 lines
12 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
|
|
import dataclasses
|
|
from collections.abc import Callable
|
|
from contextlib import ExitStack
|
|
from dataclasses import dataclass
|
|
from typing import Any
|
|
from unittest.mock import patch
|
|
|
|
import torch
|
|
import torch_npu
|
|
import vllm.envs as envs
|
|
from vllm.compilation.counter import compilation_counter
|
|
from vllm.compilation.cuda_graph import CUDAGraphOptions
|
|
from vllm.compilation.monitor import validate_cudagraph_capturing_enabled
|
|
from vllm.config import CUDAGraphMode, VllmConfig
|
|
from vllm.forward_context import BatchDescriptor, get_forward_context
|
|
from vllm.logger import logger
|
|
from vllm.platforms import current_platform
|
|
|
|
from vllm_ascend.ascend_forward_context import _EXTRA_CTX
|
|
|
|
from ..utils import weak_ref_tensors
|
|
|
|
|
|
@dataclass
class ACLGraphEntry:
    """Cached aclgraph state for a single batch descriptor.

    ACLGraphWrapper creates one entry per batch descriptor and fills in
    ``aclgraph``, ``output`` and ``input_addresses`` when it captures.
    """

    # Key under which the wrapper caches this entry.
    batch_descriptor: BatchDescriptor
    # The captured graph; stays None until a capture has happened.
    aclgraph: torch.npu.NPUGraph | None = None
    # Output produced during capture, returned again on every replay.
    output: Any | None = None

    # For aclgraph debugging: the data pointers of the tensor positional
    # inputs seen during capture, compared against the replay inputs.
    input_addresses: list[int] | None = None
|
|
|
|
|
|
class ACLGraphWrapper:
    """Wrap a runnable to add ACL graph capturing and replaying ability.

    Attribute lookups that fail on the wrapper are forwarded to the
    underlying `runnable` via `__getattr__`.

    The workflow of this wrapper in aclgraph dispatching is as follows:
    1. At initialization, a runtime mode is assigned to the wrapper (FULL or
       PIECEWISE).
    2. At runtime, the wrapper receives a runtime_mode and a
       batch_descriptor (key) from the forward context and blindly trusts
       them for aclgraph dispatching.
    3. If runtime_mode is NONE or does not match the mode of the wrapper,
       the runnable is called directly.
    4. Otherwise, i.e. the runtime_mode matches the mode of the wrapper,
       the wrapper performs aclgraph capture (if the key does not exist, a
       new entry is created and cached) or replay (if the key exists in the
       cache).

    Note: ACLGraphWrapper does not store persistent buffers or copy any
    runtime inputs into such buffers for replay; that is assumed to be done
    outside of the wrapper. The wrapper makes no assumption about the
    dynamic shape (batch size) of the runtime inputs, as a trade-off for
    staying orthogonal to compilation logic. Nevertheless, the addresses of
    the tensor positional inputs recorded during capture are checked for
    consistency during replay when VLLM_LOGGING_LEVEL == "DEBUG".
    """

    def __init__(
        self,
        runnable: Callable,
        vllm_config: VllmConfig,
        runtime_mode: CUDAGraphMode,
        cudagraph_options: CUDAGraphOptions | None = None,
    ):
        """Create a wrapper that captures/replays `runnable` in `runtime_mode`.

        Args:
            runnable: The callable to run, capture and replay.
            vllm_config: The vLLM config; its compilation and speculative
                configs are consulted.
            runtime_mode: The aclgraph mode this wrapper serves; must not be
                CUDAGraphMode.NONE.
            cudagraph_options: Capture options; a default CUDAGraphOptions()
                is used when None.
        """
        self.runnable = runnable
        self.vllm_config = vllm_config
        self.runtime_mode = runtime_mode
        self.compilation_config = vllm_config.compilation_config

        self.first_run_finished = False
        self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG"

        # runtime_mode must not be NONE (no aclgraph); otherwise there is no
        # need to create an ACLGraphWrapper at all.
        assert self.runtime_mode != CUDAGraphMode.NONE
        self.graph_pool = current_platform.get_global_graph_pool()

        if cudagraph_options is None:
            cudagraph_options = CUDAGraphOptions()
        self.aclgraph_options = cudagraph_options
        # The entries for the different batch descriptors that we need to
        # capture aclgraphs for.
        self.concrete_aclgraph_entries: dict[BatchDescriptor, ACLGraphEntry] = {}

    def __getattr__(self, key: str):
        """Forward attributes missing on the wrapper to the wrapped runnable.

        Raises:
            AttributeError: if neither the wrapper nor the runnable has `key`.
        """
        # __getattr__ only runs when normal lookup fails. If `runnable` itself
        # is missing (e.g. an instance built via `__new__` by copy/pickle
        # before its state is restored), reading `self.runnable` below would
        # re-enter __getattr__ and recurse until RecursionError.
        if key == "runnable":
            raise AttributeError(f"{type(self).__name__!r} object has no attribute 'runnable'")
        # Allow accessing the attributes of the runnable.
        if hasattr(self.runnable, key):
            return getattr(self.runnable, key)
        raise AttributeError(f"Attribute {key} not exists in the runnable of aclgraph wrapper: {self.runnable}")

    def unwrap(self) -> Callable:
        """Return the original, unwrapped runnable."""
        return self.runnable

    def _should_sync_before_replay(self) -> bool:
        """Return whether the current stream must be synchronized before replay.

        In async scheduling or multi-threaded (MT) scenarios, the CPU-side
        record event (from update_attn_params) for iteration i may complete
        before the graph replay of iteration i-1. Synchronizing before the
        replay makes update_attn_params run only after the previous graph
        replay has fully completed.

        The sync is skipped only when all three hold: this wrapper is in FULL
        mode, the draft model is running, and the speculative method is
        eagle/eagle3 (the merged-eagle-graph case).
        """
        speculative_config = self.vllm_config.speculative_config
        use_eagle = speculative_config.method in ("eagle", "eagle3") if speculative_config else False
        return self.runtime_mode != CUDAGraphMode.FULL or not _EXTRA_CTX.is_draft_model or not use_eagle

    def __call__(self, *args, **kwargs):
        """Run, capture, or replay the runnable as directed by the forward context.

        Returns:
            The runnable's output when running eagerly or while capturing.
            On replay, returns the output cached from the capture.
        """
        forward_context = get_forward_context()
        batch_descriptor = forward_context.batch_descriptor
        aclgraph_runtime_mode = forward_context.cudagraph_runtime_mode

        if aclgraph_runtime_mode == CUDAGraphMode.NONE or aclgraph_runtime_mode != self.runtime_mode:
            # CUDAGraphMode.NONE could mean the profile run, a warmup run, or
            # running without aclgraphs.
            # We do not trigger capture/replay if the runtime mode does not
            # match. This enables properly dispatching to the correct
            # ACLGraphWrapper when nesting multiple instances with different
            # runtime modes.
            return self.runnable(*args, **kwargs)

        if batch_descriptor not in self.concrete_aclgraph_entries:
            # Create a new entry for this batch descriptor.
            self.concrete_aclgraph_entries[batch_descriptor] = ACLGraphEntry(batch_descriptor=batch_descriptor)

        entry = self.concrete_aclgraph_entries[batch_descriptor]

        if entry.aclgraph is None:
            if self.aclgraph_options.debug_log_enable:
                # Since we capture aclgraph for many different shapes and
                # capturing is fast, we don't need to log it for every
                # shape. E.g. we only log it for the first subgraph in
                # piecewise mode.
                logger.debug("Capturing a aclgraph on (%s,%s)", self.runtime_mode.name, entry.batch_descriptor)
            # Validate that aclgraph capturing is legal at this point.
            validate_cudagraph_capturing_enabled()

            # Record the tensor positional-input addresses so that replays can
            # be checked against them in debug mode.
            input_addresses = [x.data_ptr() for x in args if isinstance(x, torch.Tensor)]
            entry.input_addresses = input_addresses
            aclgraph = torch.npu.NPUGraph()

            with ExitStack() as stack:
                if self.aclgraph_options.gc_disable:
                    # During every model forward for piecewise aclgraph
                    # mode, we capture many pieces of aclgraphs (roughly
                    # one per layer). Running gc again and again across
                    # layers would make the aclgraph capture very slow.
                    # Therefore, we only run gc for the first graph and
                    # disable gc for the rest of the graphs.
                    stack.enter_context(patch("gc.collect", lambda: None))
                    stack.enter_context(patch("torch.npu.empty_cache", lambda: None))

                # mind-exploding: carefully manage the reference and memory.
                # NOTE(review): `capturing` is set here but never reset in this
                # method -- confirm the forward context is recreated per step.
                forward_context.capturing = True
                with torch.npu.graph(aclgraph, pool=self.graph_pool):
                    # `output` is managed by pytorch's aclgraph pool
                    output = self.runnable(*args, **kwargs)
                    if self.aclgraph_options.weak_ref_output:
                        # By converting it to a weak ref, the original
                        # `output` is released immediately to save memory.
                        # This is only safe for the last graph in piecewise
                        # aclgraph mode, because the output of the last
                        # graph is not used by any other acl graph.
                        output = weak_ref_tensors(output)

            # Always use weak refs for the workspaces to save memory.
            # The module-level params are only read here, so no `global`
            # declaration is needed.
            weak_ref_workspaces(_graph_params)
            weak_ref_workspaces(_draft_graph_params)

            # Always use a weak ref for the cached output to save memory.
            entry.output = weak_ref_tensors(output)
            entry.aclgraph = aclgraph

            compilation_counter.num_cudagraph_captured += 1

            # Important: return the output itself rather than its weak ref,
            # so that pytorch can correctly manage the memory during acl
            # graph capture.
            return output

        if self.is_debugging_mode:
            # Check that the input addresses are the same as during capture.
            new_input_addresses = [x.data_ptr() for x in args if isinstance(x, torch.Tensor)]
            assert new_input_addresses == entry.input_addresses, (
                f"Input addresses for aclgraphs are different "
                f"during replay. Expected {entry.input_addresses}, "
                f"got {new_input_addresses}"
            )

        logger.info_once("Replaying aclgraph")
        if self._should_sync_before_replay():
            torch.npu.current_stream().synchronize()
        entry.aclgraph.replay()
        return entry.output
|
|
|
|
|
|
def weak_ref_workspaces(params):
    """Swap every non-None entry of ``params.workspaces`` for a weak reference.

    Entries that are None are left untouched. If ``params`` is None, this
    does nothing. Always returns None.
    """
    if params is not None:
        workspaces = params.workspaces
        for size_key, workspace in list(workspaces.items()):
            if workspace is not None:
                workspaces[size_key] = weak_ref_tensors(workspace)
|
|
|
|
|
|
def update_full_graph_params(
    attn_backend,
    update_stream,
    forward_context,
    num_tokens,
    vllm_config,
    speculative_config=None,
    num_dcp_pcp_tokens=None,
    draft_attn_metadatas=None,
):
    """Forward a full-graph parameter update to the attention backend.

    Resolves the backend's implementation class via ``get_impl_cls()`` and
    calls its ``update_graph_params`` with all remaining arguments,
    positionally and in declaration order. The callee's return value is
    discarded; this function returns None.
    """
    forwarded = (
        update_stream,
        forward_context,
        num_tokens,
        vllm_config,
        speculative_config,
        num_dcp_pcp_tokens,
        draft_attn_metadatas,
    )
    attn_backend.get_impl_cls().update_graph_params(*forwarded)
|
|
|
|
|
|
@dataclass
class GraphParams:
    """Per-capture-size bookkeeping for captured ACL graphs.

    Every dict is keyed by a capture size (``num_tokens``), as populated by
    ``set_graph_params`` / ``set_draft_graph_params``.
    """

    events: dict[int, list[torch.npu.ExternalEvent]]
    # Workspace per capture size. It starts as None and is filled by the
    # update_*_workspaces helpers; weak_ref_workspaces skips None entries.
    workspaces: dict[int, torch.Tensor | None]
    handles: dict[int, list[torch_npu._C._NPUTaskGroupHandle]]
    # Presumably the attention-call arguments recorded per capture size for
    # later parameter updates -- TODO confirm against the attention backend.
    attn_params: dict[int, list[tuple]]
|
|
|
|
|
|
# Main-model graph parameters; stays None until set_graph_params() is called.
_graph_params: GraphParams | None = None
|
|
|
|
|
|
def set_graph_params(aclgraph_capture_sizes: list[int]):
    """Initialise the main-model GraphParams for the given capture sizes.

    Each capture size gets empty ``events``/``handles``/``attn_params`` lists
    and a ``None`` workspace.

    Raises:
        ValueError: if the main-model graph parameters were already set.
    """
    global _graph_params
    if _graph_params is not None:
        raise ValueError("Graph parameters have already been set!")
    sizes = aclgraph_capture_sizes
    _graph_params = GraphParams(
        events={size: [] for size in sizes},
        workspaces=dict.fromkeys(sizes),
        handles={size: [] for size in sizes},
        attn_params={size: [] for size in sizes},
    )
|
|
|
|
|
|
def update_graph_params_workspaces(num_tokens: int, workspace: torch.Tensor):
    """Store ``workspace`` for capture size ``num_tokens`` in the main-model params.

    Does nothing if set_graph_params() has not been called yet.
    """
    params = _graph_params
    if params is None:
        return
    params.workspaces[num_tokens] = workspace
|
|
|
|
|
|
def get_graph_params():
    """Return the main-model GraphParams, or None if they were never set."""
    current = _graph_params
    return current
|
|
|
|
|
|
# Draft-model graph parameters; stays None until set_draft_graph_params() is called.
_draft_graph_params: GraphParams | None = None
|
|
|
|
|
|
def set_draft_graph_params(aclgraph_capture_sizes: list[int]):
    """Initialise the draft-model GraphParams for the given capture sizes.

    Each capture size gets empty ``events``/``handles``/``attn_params`` lists
    and a ``None`` workspace.

    Raises:
        ValueError: if the draft-model graph parameters were already set.
    """
    global _draft_graph_params
    if _draft_graph_params is not None:
        raise ValueError("DraftGraph parameters have already been set!")
    sizes = aclgraph_capture_sizes
    _draft_graph_params = GraphParams(
        events={size: [] for size in sizes},
        workspaces=dict.fromkeys(sizes),
        handles={size: [] for size in sizes},
        attn_params={size: [] for size in sizes},
    )
|
|
|
|
|
|
def update_draft_graph_params_workspaces(num_tokens: int, workspace: Any):
    """Store ``workspace`` for capture size ``num_tokens`` in the draft-model params.

    Does nothing if set_draft_graph_params() has not been called yet.
    """
    params = _draft_graph_params
    if params is None:
        return
    params.workspaces[num_tokens] = workspace
|
|
|
|
|
|
def get_draft_graph_params():
    """Return the draft-model GraphParams, or None if they were never set."""
    current = _draft_graph_params
    return current
|