Fix some CI issues and refactor the model runner (#2445)
### What this PR does / why we need it?
Fix some CI issues and refactor the model runner.
### Does this PR introduce _any_ user-facing change?
N/A
### How was this patch tested?
CI passed with existing tests.
- vLLM version: v0.10.0
- vLLM main: 4d9c61993a
---------
Signed-off-by: wangli <wangli858794774@gmail.com>
Signed-off-by: MengqingCao <cmq0113@163.com>
Signed-off-by: weiguihua2 <weiguihua2@huawei.com>
Co-authored-by: wangli <wangli858794774@gmail.com>
Co-authored-by: weiguihua2 <weiguihua2@huawei.com>
vllm_ascend/compilation/acl_graph.py (new file, 186 lines)
@@ -0,0 +1,186 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import dataclasses
from contextlib import ExitStack
from typing import Any, Callable, Optional
from unittest.mock import patch

import torch
import vllm.envs as envs
from vllm.compilation.counter import compilation_counter
from vllm.compilation.cuda_graph import CUDAGraphOptions
from vllm.compilation.monitor import validate_cudagraph_capturing_enabled
from vllm.config import CUDAGraphMode, VllmConfig
from vllm.forward_context import BatchDescriptor, get_forward_context
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import weak_ref_tensors

logger = init_logger(__name__)


@dataclasses.dataclass
class ACLGraphEntry:
    batch_descriptor: BatchDescriptor
    aclgraph: Optional[torch.npu.NPUGraph] = None
    output: Optional[Any] = None

    # for aclgraph debugging, track the input addresses
    # during capture, and check if they are the same during replay
    input_addresses: Optional[list[int]] = None


class ACLGraphWrapper:
    """Wraps a runnable to add acl graph capturing and replaying ability, and
    provides attribute access to the underlying `runnable` via `__getattr__`.

    The workflow of this wrapper in the aclgraph dispatching is as follows:
    1. At initialization, a runtime mode is assigned to the wrapper (FULL or
       PIECEWISE).
    2. At runtime, the wrapper receives a runtime_mode and a
       batch_descriptor (key) from the forward context and blindly trusts them
       for aclgraph dispatching.
    3. If runtime_mode is NONE or runtime_mode does not match the mode of the
       wrapper, just call the runnable directly.
    4. Otherwise, i.e., the runtime_mode matches the mode of the wrapper,
       the wrapper will perform aclgraph capture (if the key does not exist,
       create a new entry and cache it) or replay (if the key exists in the
       cache).

    Note: ACLGraphWrapper does not store persistent buffers or copy any
    runtime inputs into those buffers for replay. We assume implementing that
    is done outside of the wrapper. That is because we do not make any
    assumption on the dynamic shape (batch size) of the runtime inputs, as a
    trade-off for staying orthogonal to compilation logic. Nevertheless,
    tracing and checking that the input addresses are consistent during replay
    is guaranteed when VLLM_LOGGING_LEVEL == "DEBUG".
    """

    def __init__(self,
                 runnable: Callable,
                 vllm_config: VllmConfig,
                 runtime_mode: CUDAGraphMode,
                 graph_pool: Any = None,
                 cudagraph_options: Optional[CUDAGraphOptions] = None):
        self.runnable = runnable
        self.vllm_config = vllm_config
        self.graph_pool = graph_pool
        self.runtime_mode = runtime_mode
        self.compilation_config = vllm_config.compilation_config

        self.first_run_finished = False
        self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG"

        # assert runtime_mode is not NONE (no aclgraph); otherwise, we don't
        # need to initialize an ACLGraphWrapper.
        assert self.runtime_mode != CUDAGraphMode.NONE
        if self.graph_pool is None:
            self.graph_pool = current_platform.get_global_graph_pool()

        if cudagraph_options is None:
            cudagraph_options = CUDAGraphOptions()
        self.aclgraph_options = cudagraph_options
        # the entries for different batch descriptors that we need to capture
        # aclgraphs for.
        self.concrete_aclgraph_entries: dict[BatchDescriptor, ACLGraphEntry]\
            = {}

    def __getattr__(self, key: str):
        # allow accessing the attributes of the runnable.
        if hasattr(self.runnable, key):
            return getattr(self.runnable, key)
        raise AttributeError(f"Attribute {key} does not exist in the runnable "
                             f"of aclgraph wrapper: {self.runnable}")

    def unwrap(self) -> Callable:
        # in case we need to access the original runnable.
        return self.runnable

    def __call__(self, *args, **kwargs):
        forward_context = get_forward_context()
        batch_descriptor = forward_context.batch_descriptor
        aclgraph_runtime_mode = forward_context.cudagraph_runtime_mode

        if aclgraph_runtime_mode == CUDAGraphMode.NONE or \
            aclgraph_runtime_mode != self.runtime_mode:
            # CUDAGraphMode.NONE could mean the profile run, a warmup run, or
            # running without aclgraphs.
            # We do not trigger capture/replay if the runtime mode does not
            # match. This enables properly dispatching to the correct
            # CUDAGraphWrapper when nesting multiple instances with different
            # runtime modes.
            return self.runnable(*args, **kwargs)

        if batch_descriptor not in self.concrete_aclgraph_entries:
            # create a new entry for this batch descriptor
            self.concrete_aclgraph_entries[batch_descriptor] = \
                ACLGraphEntry(batch_descriptor=batch_descriptor)

        entry = self.concrete_aclgraph_entries[batch_descriptor]

        if entry.aclgraph is None:
            if self.aclgraph_options.debug_log_enable:
                # Since we capture aclgraphs for many different shapes and
                # capturing is fast, we don't need to log it for every
                # shape. E.g. we only log it for the first subgraph in
                # piecewise mode.
                logger.debug("Capturing an aclgraph on (%s,%s)",
                             self.runtime_mode.name, entry.batch_descriptor)
            # validate that aclgraph capturing is legal at this point.
            validate_cudagraph_capturing_enabled()

            input_addresses = [
                x.data_ptr() for x in args if isinstance(x, torch.Tensor)
            ]
            entry.input_addresses = input_addresses
            aclgraph = torch.npu.NPUGraph()

            with ExitStack() as stack:
                if self.aclgraph_options.gc_disable:
                    # during every model forward for piecewise aclgraph
                    # mode, we will capture many pieces of aclgraphs
                    # (roughly one per layer). running gc again and again
                    # across layers will make the aclgraph capture very slow.
                    # therefore, we only run gc for the first graph,
                    # and disable gc for the rest of the graphs.
                    stack.enter_context(patch("gc.collect", lambda: None))
                    stack.enter_context(
                        patch("torch.npu.empty_cache", lambda: None))

                # mind-exploding: carefully manage the reference and memory.
                with torch.npu.graph(aclgraph, pool=self.graph_pool):
                    # `output` is managed by pytorch's aclgraph pool
                    output = self.runnable(*args, **kwargs)
                    if self.aclgraph_options.weak_ref_output:
                        # by converting it to a weak ref,
                        # the original `output` will immediately be released
                        # to save memory. It is only safe to do this for
                        # the last graph in piecewise aclgraph mode, because
                        # the output of the last graph will not be used by
                        # any other acl graph.
                        output = weak_ref_tensors(output)

            # here we always use a weak ref for the output
            # to save memory
            entry.output = weak_ref_tensors(output)
            entry.aclgraph = aclgraph

            compilation_counter.num_cudagraph_captured += 1

            # important: we need to return the output, rather than
            # the weak ref of the output, so that pytorch can correctly
            # manage the memory during acl graph capture
            return output

        if self.is_debugging_mode:
            # check if the input addresses are the same
            new_input_addresses = [
                x.data_ptr() for x in args if isinstance(x, torch.Tensor)
            ]
            assert new_input_addresses == entry.input_addresses, (
                "Input addresses for aclgraphs are different "
                f"during replay. Expected {entry.input_addresses}, "
                f"got {new_input_addresses}")

        entry.aclgraph.replay()
        return entry.output
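A minimal usage sketch of the new `ACLGraphWrapper` (not part of this commit): it assumes an NPU build of PyTorch (`torch.npu`), that `BatchDescriptor` takes a `num_tokens` field, and that `set_forward_context` on current vLLM main accepts `cudagraph_runtime_mode` and `batch_descriptor` keyword arguments; `model_fn`, the default-constructed `VllmConfig`, and the batch shape are hypothetical. The first call for a given batch descriptor captures an NPU graph, and later calls with the same descriptor (and the same input buffer addresses) replay it.

import torch
from vllm.config import CUDAGraphMode, VllmConfig
from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm_ascend.compilation.acl_graph import ACLGraphWrapper

def model_fn(x: torch.Tensor) -> torch.Tensor:
    # stand-in for the real runnable (e.g. a compiled model piece)
    return x * 2

cfg = VllmConfig()  # illustrative only; in practice this comes from the engine
wrapped = ACLGraphWrapper(model_fn, cfg, runtime_mode=CUDAGraphMode.FULL)

# persistent input buffer: replay requires the same tensor addresses as capture
x = torch.zeros(8, 128, device="npu")
desc = BatchDescriptor(num_tokens=8)

# keyword arguments below follow current vLLM main and may differ across versions
with set_forward_context(None, cfg,
                         cudagraph_runtime_mode=CUDAGraphMode.FULL,
                         batch_descriptor=desc):
    wrapped(x)        # first call for this descriptor: captures an aclgraph
    out = wrapped(x)  # same descriptor and addresses: replays the captured graph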