[BugFix][Fusion] Fix graph fusion failure problem (#5676)
Currently, the vllm pull request
(https://github.com/vllm-project/vllm/pull/24252) is causing operator
fusion to fail. This issue was previously fixed by patching the backend.
The root cause has been identified, and the problem can be resolved with
this pull request.
- vLLM version: v0.13.0
- vLLM main:
2f4e6548ef
---------
Signed-off-by: wxsIcey <1790571317@qq.com>
This commit is contained in:
@@ -26,6 +26,7 @@ from torch._inductor.compile_fx import (graph_returns_tuple,
|
|||||||
from torch._inductor.decomposition import select_decomp_table
|
from torch._inductor.decomposition import select_decomp_table
|
||||||
from torch.fx import GraphModule
|
from torch.fx import GraphModule
|
||||||
from vllm.compilation.compiler_interface import CompilerInterface
|
from vllm.compilation.compiler_interface import CompilerInterface
|
||||||
|
from vllm.config.utils import Range
|
||||||
|
|
||||||
from vllm_ascend.ascend_config import get_ascend_config
|
from vllm_ascend.ascend_config import get_ascend_config
|
||||||
from vllm_ascend.utils import COMPILATION_PASS_KEY
|
from vllm_ascend.utils import COMPILATION_PASS_KEY
|
||||||
@@ -47,13 +48,13 @@ def fusion_pass_compile(
|
|||||||
graph: fx.GraphModule,
|
graph: fx.GraphModule,
|
||||||
example_inputs: list[Any],
|
example_inputs: list[Any],
|
||||||
compiler_config: dict[str, Any],
|
compiler_config: dict[str, Any],
|
||||||
runtime_shape: Optional[int] = None,
|
compile_range: Range,
|
||||||
key: Optional[str] = None,
|
key: Optional[str] = None,
|
||||||
) -> tuple[Optional[Callable], Optional[Any]]:
|
) -> tuple[Optional[Callable], Optional[Any]]:
|
||||||
|
|
||||||
def compile_inner(graph, example_inputs):
|
def compile_inner(graph, example_inputs):
|
||||||
current_pass_manager = compiler_config[COMPILATION_PASS_KEY]
|
current_pass_manager = compiler_config[COMPILATION_PASS_KEY]
|
||||||
graph = current_pass_manager(graph, runtime_shape)
|
graph = current_pass_manager(graph)
|
||||||
return graph
|
return graph
|
||||||
|
|
||||||
decompositions = select_decomp_table()
|
decompositions = select_decomp_table()
|
||||||
@@ -72,7 +73,7 @@ def npugraph_ex_compile(
|
|||||||
graph: fx.GraphModule,
|
graph: fx.GraphModule,
|
||||||
example_inputs: list[Any],
|
example_inputs: list[Any],
|
||||||
compiler_config: dict[str, Any],
|
compiler_config: dict[str, Any],
|
||||||
runtime_shape: Optional[int] = None,
|
compile_range: Range,
|
||||||
key: Optional[str] = None,
|
key: Optional[str] = None,
|
||||||
) -> tuple[Optional[Callable], Optional[Any]]:
|
) -> tuple[Optional[Callable], Optional[Any]]:
|
||||||
# When currently using the FULL_DECODE_ONLY mode,
|
# When currently using the FULL_DECODE_ONLY mode,
|
||||||
@@ -125,14 +126,14 @@ class AscendCompiler(CompilerInterface):
|
|||||||
graph: fx.GraphModule,
|
graph: fx.GraphModule,
|
||||||
example_inputs: list[Any],
|
example_inputs: list[Any],
|
||||||
compiler_config: dict[str, Any],
|
compiler_config: dict[str, Any],
|
||||||
runtime_shape: Optional[int] = None,
|
compile_range: Range,
|
||||||
key: Optional[str] = None,
|
key: Optional[str] = None,
|
||||||
) -> tuple[Optional[Callable], Optional[Any]]:
|
) -> tuple[Optional[Callable], Optional[Any]]:
|
||||||
|
|
||||||
ascend_config = get_ascend_config()
|
ascend_config = get_ascend_config()
|
||||||
if ascend_config.enable_npugraph_ex:
|
if ascend_config.enable_npugraph_ex:
|
||||||
return npugraph_ex_compile(graph, example_inputs, compiler_config,
|
return npugraph_ex_compile(graph, example_inputs, compiler_config,
|
||||||
runtime_shape, key)
|
compile_range, key)
|
||||||
else:
|
else:
|
||||||
return fusion_pass_compile(graph, example_inputs, compiler_config,
|
return fusion_pass_compile(graph, example_inputs, compiler_config,
|
||||||
runtime_shape, key)
|
compile_range, key)
|
||||||
|
|||||||
@@ -17,6 +17,7 @@
|
|||||||
#
|
#
|
||||||
|
|
||||||
from torch import fx as fx
|
from torch import fx as fx
|
||||||
|
from vllm.compilation.inductor_pass import get_pass_context
|
||||||
from vllm.compilation.vllm_inductor_pass import VllmInductorPass
|
from vllm.compilation.vllm_inductor_pass import VllmInductorPass
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
|
|
||||||
@@ -32,10 +33,13 @@ class GraphFusionPassManager:
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.passes: list[VllmInductorPass] = []
|
self.passes: list[VllmInductorPass] = []
|
||||||
|
|
||||||
def __call__(self, graph: fx.Graph, runtime_shape) -> fx.Graph:
|
def __call__(self, graph: fx.Graph) -> fx.Graph:
|
||||||
|
compile_range = get_pass_context().compile_range
|
||||||
|
|
||||||
for pass_ in self.passes:
|
for pass_ in self.passes:
|
||||||
if pass_.is_applicable(runtime_shape):
|
if pass_.is_applicable_for_range(compile_range):
|
||||||
pass_(graph)
|
pass_(graph)
|
||||||
|
graph.recompile()
|
||||||
return graph
|
return graph
|
||||||
|
|
||||||
def add(self, pass_: VllmInductorPass):
|
def add(self, pass_: VllmInductorPass):
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ import torch._inductor.pattern_matcher as pm
|
|||||||
from torch._inductor.pattern_matcher import PatternMatcherPass
|
from torch._inductor.pattern_matcher import PatternMatcherPass
|
||||||
from vllm.compilation.vllm_inductor_pass import VllmInductorPass
|
from vllm.compilation.vllm_inductor_pass import VllmInductorPass
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
|
from vllm.config.compilation import Range
|
||||||
from vllm.logger import logger
|
from vllm.logger import logger
|
||||||
|
|
||||||
|
|
||||||
@@ -308,7 +309,7 @@ class AddRMSNormQuantFusionPass(VllmInductorPass):
|
|||||||
logger.debug("Replaced %s patterns", self.matched_count)
|
logger.debug("Replaced %s patterns", self.matched_count)
|
||||||
self.end_and_log()
|
self.end_and_log()
|
||||||
|
|
||||||
def is_applicable(self, runtime_shape: int | None = None) -> bool:
|
def is_applicable_for_range(self, compile_range: Range) -> bool:
|
||||||
"""
|
"""
|
||||||
Check if the pass is applicable for the current configuration.
|
Check if the pass is applicable for the current configuration.
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ from torch._inductor.pattern_matcher import (PatternMatcherPass,
|
|||||||
from vllm.attention.layer import Attention
|
from vllm.attention.layer import Attention
|
||||||
from vllm.compilation.vllm_inductor_pass import VllmInductorPass
|
from vllm.compilation.vllm_inductor_pass import VllmInductorPass
|
||||||
from vllm.config import VllmConfig, get_layers_from_vllm_config
|
from vllm.config import VllmConfig, get_layers_from_vllm_config
|
||||||
|
from vllm.config.compilation import Range
|
||||||
from vllm.logger import logger
|
from vllm.logger import logger
|
||||||
|
|
||||||
|
|
||||||
@@ -283,7 +284,7 @@ class QKNormRopeFusionPass(VllmInductorPass):
|
|||||||
pattern_idx += 1
|
pattern_idx += 1
|
||||||
self.end_and_log()
|
self.end_and_log()
|
||||||
|
|
||||||
def is_applicable(self, runtime_shape):
|
def is_applicable_for_range(self, compile_range: Range) -> bool:
|
||||||
"""
|
"""
|
||||||
Check if the pass is applicable for the current configuration.
|
Check if the pass is applicable for the current configuration.
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -106,20 +106,6 @@
|
|||||||
# Future Plan:
|
# Future Plan:
|
||||||
# Remove this patch when vLLM merge the PR.
|
# Remove this patch when vLLM merge the PR.
|
||||||
#
|
#
|
||||||
# ** 7. File: platform/patch_compile_backend.py**
|
|
||||||
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
|
||||||
# 1. `vllm.compilation.backends.PiecewiseCompileInterpreter`
|
|
||||||
# `vllm.compilation.piecewise_backend.PiecewiseBackend`
|
|
||||||
# Why:
|
|
||||||
# vllm removed the compile graph for general shape, which caused operator fusion to fail.
|
|
||||||
# This issue affects the performance of model inference on Ascend.
|
|
||||||
# How:
|
|
||||||
# recover the compiled graph for dynamic_shape in PiecewiseBackend.
|
|
||||||
# Related PR (if no, explain why):
|
|
||||||
# https://github.com/vllm-project/vllm/pull/24252
|
|
||||||
# Future Plan:
|
|
||||||
# Remove this patch when fix the problem.
|
|
||||||
#
|
|
||||||
# * Worker Patch:
|
# * Worker Patch:
|
||||||
# ===============
|
# ===============
|
||||||
#
|
#
|
||||||
|
|||||||
@@ -16,7 +16,6 @@
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
import vllm_ascend.patch.platform.patch_compile_backend # noqa
|
|
||||||
import vllm_ascend.patch.platform.patch_distributed # noqa
|
import vllm_ascend.patch.platform.patch_distributed # noqa
|
||||||
import vllm_ascend.patch.platform.patch_ec_connector # noqa
|
import vllm_ascend.patch.platform.patch_ec_connector # noqa
|
||||||
import vllm_ascend.patch.platform.patch_mamba_config # noqa
|
import vllm_ascend.patch.platform.patch_mamba_config # noqa
|
||||||
|
|||||||
@@ -1,235 +0,0 @@
|
|||||||
from collections.abc import Callable
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.fx as fx
|
|
||||||
import vllm.compilation.backends
|
|
||||||
import vllm.compilation.piecewise_backend
|
|
||||||
from torch._dispatch.python import enable_python_dispatcher
|
|
||||||
from vllm.compilation.backends import VllmBackend
|
|
||||||
from vllm.compilation.counter import compilation_counter
|
|
||||||
from vllm.compilation.piecewise_backend import RangeEntry
|
|
||||||
from vllm.config import CUDAGraphMode, VllmConfig
|
|
||||||
from vllm.config.utils import Range
|
|
||||||
from vllm.logger import init_logger
|
|
||||||
from vllm.platforms import current_platform
|
|
||||||
from vllm.utils.import_utils import resolve_obj_by_qualname
|
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class AscendPiecewiseCompileInterpreter(torch.fx.Interpreter):
|
|
||||||
"""Code adapted from `torch.fx.passes.shape_prop.ShapeProp`.
|
|
||||||
It runs the given graph with fake inputs, and compile some
|
|
||||||
submodules specified by `compile_submod_names` with the given
|
|
||||||
compilation configs.
|
|
||||||
|
|
||||||
NOTE: the order in `compile_submod_names` matters, because
|
|
||||||
it will be used to determine the order of the compiled piecewise
|
|
||||||
graphs. The first graph will handle logging, and the last graph
|
|
||||||
has some special cudagraph output handling.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
module: torch.fx.GraphModule,
|
|
||||||
compile_submod_names: list[str],
|
|
||||||
vllm_config: VllmConfig,
|
|
||||||
vllm_backend: "VllmBackend",
|
|
||||||
):
|
|
||||||
super().__init__(module)
|
|
||||||
from torch._guards import detect_fake_mode
|
|
||||||
|
|
||||||
self.fake_mode = detect_fake_mode()
|
|
||||||
self.compile_submod_names = compile_submod_names
|
|
||||||
self.compilation_config = vllm_config.compilation_config
|
|
||||||
self.vllm_config = vllm_config
|
|
||||||
self.vllm_backend = vllm_backend
|
|
||||||
# When True, it annoyingly dumps the torch.fx.Graph on errors.
|
|
||||||
self.extra_traceback = False
|
|
||||||
|
|
||||||
def run(self, *args):
|
|
||||||
# maybe instead just assert inputs are fake?
|
|
||||||
fake_args = [
|
|
||||||
self.fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t
|
|
||||||
for t in args
|
|
||||||
]
|
|
||||||
with self.fake_mode, enable_python_dispatcher():
|
|
||||||
return super().run(*fake_args)
|
|
||||||
|
|
||||||
def call_module(
|
|
||||||
self,
|
|
||||||
target: torch.fx.node.Target,
|
|
||||||
args: tuple[torch.fx.node.Argument, ...],
|
|
||||||
kwargs: dict[str, Any],
|
|
||||||
) -> Any:
|
|
||||||
assert isinstance(target, str)
|
|
||||||
|
|
||||||
output = super().call_module(target, args, kwargs)
|
|
||||||
|
|
||||||
if target in self.compile_submod_names:
|
|
||||||
index = self.compile_submod_names.index(target)
|
|
||||||
submod = self.fetch_attr(target)
|
|
||||||
|
|
||||||
sym_shape_indices = [
|
|
||||||
i for i, x in enumerate(args) if isinstance(x, torch.SymInt)
|
|
||||||
]
|
|
||||||
max_num_batched_tokens = self.vllm_config.scheduler_config.max_num_batched_tokens
|
|
||||||
r1 = Range(start=1, end=max_num_batched_tokens)
|
|
||||||
compiled_graph_for_dynamic_shape = (
|
|
||||||
self.vllm_backend.compiler_manager.compile(
|
|
||||||
submod,
|
|
||||||
args,
|
|
||||||
self.vllm_backend.inductor_config,
|
|
||||||
self.compilation_config,
|
|
||||||
graph_index=index,
|
|
||||||
num_graphs=len(self.compile_submod_names),
|
|
||||||
compile_range=r1,
|
|
||||||
))
|
|
||||||
|
|
||||||
# Lazy import here to avoid circular import
|
|
||||||
from vllm.compilation.piecewise_backend import PiecewiseBackend
|
|
||||||
|
|
||||||
piecewise_backend = PiecewiseBackend(
|
|
||||||
submod,
|
|
||||||
self.vllm_config,
|
|
||||||
index,
|
|
||||||
len(self.compile_submod_names),
|
|
||||||
sym_shape_indices,
|
|
||||||
compiled_graph_for_dynamic_shape,
|
|
||||||
self.vllm_backend,
|
|
||||||
)
|
|
||||||
|
|
||||||
if (self.compilation_config.cudagraph_mode.
|
|
||||||
has_piecewise_cudagraphs() and
|
|
||||||
not self.compilation_config.use_inductor_graph_partition):
|
|
||||||
# We're using Dynamo-based piecewise splitting, so we wrap
|
|
||||||
# the whole subgraph with a static graph wrapper.
|
|
||||||
from vllm.compilation.cuda_graph import CUDAGraphOptions
|
|
||||||
|
|
||||||
# resolve the static graph wrapper class (e.g. CUDAGraphWrapper
|
|
||||||
# class) as platform dependent.
|
|
||||||
static_graph_wrapper_class = resolve_obj_by_qualname(
|
|
||||||
current_platform.get_static_graph_wrapper_cls())
|
|
||||||
|
|
||||||
# Always assign PIECEWISE runtime mode to the
|
|
||||||
# CUDAGraphWrapper for piecewise_backend, to distinguish
|
|
||||||
# it from the FULL cudagraph runtime mode, no matter it
|
|
||||||
# is wrapped on a full or piecewise fx graph.
|
|
||||||
self.module.__dict__[target] = static_graph_wrapper_class(
|
|
||||||
runnable=piecewise_backend,
|
|
||||||
vllm_config=self.vllm_config,
|
|
||||||
runtime_mode=CUDAGraphMode.PIECEWISE,
|
|
||||||
cudagraph_options=CUDAGraphOptions(
|
|
||||||
debug_log_enable=piecewise_backend.is_first_graph,
|
|
||||||
gc_disable=not piecewise_backend.is_first_graph,
|
|
||||||
weak_ref_output=piecewise_backend.is_last_graph,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.module.__dict__[target] = piecewise_backend
|
|
||||||
|
|
||||||
compilation_counter.num_piecewise_capturable_graphs_seen += 1
|
|
||||||
|
|
||||||
return output
|
|
||||||
|
|
||||||
|
|
||||||
class AscendPiecewiseBackend:
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
graph: fx.GraphModule,
|
|
||||||
vllm_config: VllmConfig,
|
|
||||||
piecewise_compile_index: int,
|
|
||||||
total_piecewise_compiles: int,
|
|
||||||
sym_shape_indices: list[int],
|
|
||||||
compiled_graph_for_general_shape: Callable,
|
|
||||||
vllm_backend: VllmBackend,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
The backend for piecewise compilation.
|
|
||||||
It mainly handles the compilation of static shapes and
|
|
||||||
dispatching based on runtime shape.
|
|
||||||
|
|
||||||
We will compile `self.graph` once for the general shape,
|
|
||||||
and then compile for different shapes specified in
|
|
||||||
`compilation_config.compile_sizes`.
|
|
||||||
"""
|
|
||||||
self.compiled_graph_for_general_shape = compiled_graph_for_general_shape
|
|
||||||
self.graph = graph
|
|
||||||
self.vllm_config = vllm_config
|
|
||||||
self.compilation_config = vllm_config.compilation_config
|
|
||||||
self.piecewise_compile_index = piecewise_compile_index
|
|
||||||
self.total_piecewise_compiles = total_piecewise_compiles
|
|
||||||
self.vllm_backend = vllm_backend
|
|
||||||
|
|
||||||
self.is_first_graph = piecewise_compile_index == 0
|
|
||||||
self.is_last_graph = piecewise_compile_index == total_piecewise_compiles - 1
|
|
||||||
|
|
||||||
self.is_full_graph = total_piecewise_compiles == 1
|
|
||||||
self.is_encoder_compilation = vllm_backend.is_encoder
|
|
||||||
|
|
||||||
self.compile_ranges = self.compilation_config.get_compile_ranges()
|
|
||||||
if self.is_encoder_compilation:
|
|
||||||
# For encoder compilation we use the max int32 value
|
|
||||||
# to set the upper bound of the compile ranges
|
|
||||||
max_int32 = 2**31 - 1
|
|
||||||
last_compile_range = self.compile_ranges[-1]
|
|
||||||
assert (last_compile_range.end ==
|
|
||||||
vllm_config.scheduler_config.max_num_batched_tokens)
|
|
||||||
self.compile_ranges[-1] = Range(start=last_compile_range.start,
|
|
||||||
end=max_int32)
|
|
||||||
|
|
||||||
log_string = f"PiecewiseBackend: compile_ranges: {self.compile_ranges}"
|
|
||||||
logger.debug_once(log_string)
|
|
||||||
|
|
||||||
self.compile_sizes = self.compilation_config.compile_sizes
|
|
||||||
log_string = f"PiecewiseBackend: compile_sizes: {self.compile_sizes}"
|
|
||||||
logger.debug_once(log_string)
|
|
||||||
|
|
||||||
self.sym_shape_indices = sym_shape_indices
|
|
||||||
|
|
||||||
# the entries for ranges that we need to either
|
|
||||||
self.range_entries: dict[Range, RangeEntry] = {}
|
|
||||||
|
|
||||||
# to_be_compiled_ranges tracks the remaining ranges to compile,
|
|
||||||
# and updates during the compilation process, so we need to copy it
|
|
||||||
self.to_be_compiled_ranges: set[Range] = set(self.compile_ranges)
|
|
||||||
|
|
||||||
# We only keep compilation management inside this class directly.
|
|
||||||
for size in self.compile_sizes:
|
|
||||||
range = Range(start=size, end=size)
|
|
||||||
if range not in self.compile_ranges:
|
|
||||||
self.range_entries[range] = RangeEntry(compile_range=range, )
|
|
||||||
self.to_be_compiled_ranges.add(range)
|
|
||||||
|
|
||||||
for range in self.compile_ranges:
|
|
||||||
self.range_entries[range] = RangeEntry(compile_range=range, )
|
|
||||||
|
|
||||||
def _find_range_for_shape(self, runtime_shape: int) -> Range | None:
|
|
||||||
# First we try to find the range entry for the concrete compile size
|
|
||||||
# If not found, we search for the range entry
|
|
||||||
# that contains the runtime shape.
|
|
||||||
if runtime_shape in self.compile_sizes:
|
|
||||||
return self.range_entries[Range(start=runtime_shape,
|
|
||||||
end=runtime_shape)]
|
|
||||||
else:
|
|
||||||
for range in self.compile_ranges:
|
|
||||||
if runtime_shape in range:
|
|
||||||
return self.range_entries[range]
|
|
||||||
return None
|
|
||||||
|
|
||||||
def __call__(self, *args) -> Any:
|
|
||||||
runtime_shape = args[self.sym_shape_indices[0]]
|
|
||||||
range_entry = self._find_range_for_shape(runtime_shape)
|
|
||||||
|
|
||||||
assert range_entry is not None, (
|
|
||||||
f"Shape out of considered range: {runtime_shape} "
|
|
||||||
"[1, max_num_batched_tokens]")
|
|
||||||
|
|
||||||
return self.compiled_graph_for_general_shape(*args)
|
|
||||||
|
|
||||||
|
|
||||||
vllm.compilation.backends.PiecewiseCompileInterpreter = AscendPiecewiseCompileInterpreter
|
|
||||||
vllm.compilation.piecewise_backend.PiecewiseBackend.__init__ = AscendPiecewiseBackend.__init__
|
|
||||||
vllm.compilation.piecewise_backend.PiecewiseBackend.__call__ = AscendPiecewiseBackend.__call__
|
|
||||||
@@ -28,7 +28,7 @@ import torch_npu
|
|||||||
import vllm.envs as envs_vllm
|
import vllm.envs as envs_vllm
|
||||||
from torch_npu.op_plugin.atb._atb_ops import _register_atb_extensions
|
from torch_npu.op_plugin.atb._atb_ops import _register_atb_extensions
|
||||||
from torch_npu.profiler import dynamic_profile as dp
|
from torch_npu.profiler import dynamic_profile as dp
|
||||||
from vllm.config import VllmConfig, set_current_vllm_config
|
from vllm.config import CUDAGraphMode, VllmConfig, set_current_vllm_config
|
||||||
from vllm.distributed import (ensure_model_parallel_initialized,
|
from vllm.distributed import (ensure_model_parallel_initialized,
|
||||||
init_distributed_environment)
|
init_distributed_environment)
|
||||||
from vllm.distributed.ec_transfer import ensure_ec_transfer_initialized
|
from vllm.distributed.ec_transfer import ensure_ec_transfer_initialized
|
||||||
@@ -381,10 +381,25 @@ class NPUWorker(WorkerBase):
|
|||||||
warmup_sizes = (self.vllm_config.compilation_config.compile_sizes
|
warmup_sizes = (self.vllm_config.compilation_config.compile_sizes
|
||||||
or []).copy()
|
or []).copy()
|
||||||
if not self.model_config.enforce_eager:
|
if not self.model_config.enforce_eager:
|
||||||
|
cg_capture_sizes: list[int] = []
|
||||||
|
if self.vllm_config.compilation_config.cudagraph_mode != CUDAGraphMode.NONE:
|
||||||
|
cg_sizes = self.vllm_config.compilation_config.cudagraph_capture_sizes
|
||||||
|
cg_capture_sizes = [] if cg_sizes is None else cg_sizes
|
||||||
warmup_sizes = [
|
warmup_sizes = [
|
||||||
x for x in warmup_sizes if x not in
|
x for x in warmup_sizes if x not in cg_capture_sizes
|
||||||
self.vllm_config.compilation_config.cudagraph_capture_sizes
|
|
||||||
]
|
]
|
||||||
|
|
||||||
|
compile_ranges = self.vllm_config.compilation_config.get_compile_ranges(
|
||||||
|
)
|
||||||
|
# For each compile_range, if none of the batch sizes
|
||||||
|
# in warmup_sizes or cudagraph_capture_sizes are in the range,
|
||||||
|
# add the end of the range to ensure compilation/warmup.
|
||||||
|
all_sizes = set(cg_capture_sizes)
|
||||||
|
all_sizes.update([x for x in warmup_sizes if isinstance(x, int)])
|
||||||
|
for compile_range in compile_ranges:
|
||||||
|
if not any(x in compile_range for x in all_sizes):
|
||||||
|
warmup_sizes.append(compile_range.end)
|
||||||
|
|
||||||
for size in sorted(warmup_sizes, reverse=True):
|
for size in sorted(warmup_sizes, reverse=True):
|
||||||
logger.info("Compile and warming up model for size %d", size)
|
logger.info("Compile and warming up model for size %d", size)
|
||||||
self.model_runner._dummy_run(size)
|
self.model_runner._dummy_run(size)
|
||||||
|
|||||||
Reference in New Issue
Block a user