diff --git a/tests/ut/worker/test_worker_v1.py b/tests/ut/worker/test_worker_v1.py index 72ad5b3a..e3f6dd14 100644 --- a/tests/ut/worker/test_worker_v1.py +++ b/tests/ut/worker/test_worker_v1.py @@ -1020,6 +1020,8 @@ class TestNPUWorker(TestBase): # Verify _dummy_run call count and order (by size descending) expected_calls = [ unittest.mock.call(16), + unittest.mock.call(8), + unittest.mock.call(4), unittest.mock.call(1), ] worker.model_runner._dummy_run.assert_has_calls(expected_calls) @@ -1028,7 +1030,7 @@ class TestNPUWorker(TestBase): worker.model_runner.capture_model.assert_not_called() # Verify log output - self.assertEqual(mock_logger.info.call_count, 2) + self.assertEqual(mock_logger.info.call_count, 4) # Verify seed setting mock_seed_everything.assert_called_once_with(12345) diff --git a/vllm_ascend/compilation/compiler_interface.py b/vllm_ascend/compilation/compiler_interface.py index deff99bb..1c706806 100644 --- a/vllm_ascend/compilation/compiler_interface.py +++ b/vllm_ascend/compilation/compiler_interface.py @@ -26,7 +26,6 @@ from torch._inductor.compile_fx import (graph_returns_tuple, from torch._inductor.decomposition import select_decomp_table from torch.fx import GraphModule from vllm.compilation.compiler_interface import CompilerInterface -from vllm.config.utils import Range from vllm_ascend.ascend_config import get_ascend_config @@ -47,13 +46,13 @@ def fusion_pass_compile( graph: fx.GraphModule, example_inputs: list[Any], compiler_config: dict[str, Any], - compile_range: Range, + runtime_shape: Optional[int] = None, key: Optional[str] = None, ) -> tuple[Optional[Callable], Optional[Any]]: def compile_inner(graph, example_inputs): current_pass_manager = compiler_config["graph_fusion_manager"] - graph = current_pass_manager(graph, compile_range) + graph = current_pass_manager(graph, runtime_shape) return graph decompositions = select_decomp_table() @@ -72,7 +71,7 @@ def npugraph_ex_compile( graph: fx.GraphModule, example_inputs: list[Any], 
compiler_config: dict[str, Any], - compile_range: Range, + runtime_shape: Optional[int] = None, key: Optional[str] = None, ) -> tuple[Optional[Callable], Optional[Any]]: # When currently using the FULL_DECODE_ONLY mode, @@ -125,14 +124,14 @@ class AscendCompiler(CompilerInterface): graph: fx.GraphModule, example_inputs: list[Any], compiler_config: dict[str, Any], - compile_range: Range, + runtime_shape: Optional[int] = None, key: Optional[str] = None, ) -> tuple[Optional[Callable], Optional[Any]]: ascend_config = get_ascend_config() if ascend_config.enable_npugraph_ex: return npugraph_ex_compile(graph, example_inputs, compiler_config, - compile_range, key) + runtime_shape, key) else: return fusion_pass_compile(graph, example_inputs, compiler_config, - compile_range, key) + runtime_shape, key) diff --git a/vllm_ascend/compilation/graph_fusion_pass_manager.py b/vllm_ascend/compilation/graph_fusion_pass_manager.py index 4e458dde..e311b260 100644 --- a/vllm_ascend/compilation/graph_fusion_pass_manager.py +++ b/vllm_ascend/compilation/graph_fusion_pass_manager.py @@ -17,7 +17,6 @@ # from torch import fx as fx -from vllm.compilation.inductor_pass import get_pass_context from vllm.compilation.vllm_inductor_pass import VllmInductorPass from vllm.config import VllmConfig @@ -33,13 +32,10 @@ class GraphFusionPassManager: def __init__(self): self.passes: list[VllmInductorPass] = [] - def __call__(self, graph: fx.Graph, compile_range) -> fx.Graph: - compile_range = get_pass_context().compile_range - + def __call__(self, graph: fx.Graph, runtime_shape) -> fx.Graph: for pass_ in self.passes: - if pass_.is_applicable_for_range(compile_range): + if pass_.is_applicable(runtime_shape): pass_(graph) - graph.recompile() return graph def add(self, pass_: VllmInductorPass): diff --git a/vllm_ascend/compilation/passes/norm_quant_fusion_pass.py b/vllm_ascend/compilation/passes/norm_quant_fusion_pass.py index eeaccd80..f929c1a4 100644 --- 
a/vllm_ascend/compilation/passes/norm_quant_fusion_pass.py +++ b/vllm_ascend/compilation/passes/norm_quant_fusion_pass.py @@ -20,7 +20,6 @@ import torch._inductor.pattern_matcher as pm from torch._inductor.pattern_matcher import PatternMatcherPass from vllm.compilation.vllm_inductor_pass import VllmInductorPass from vllm.config import VllmConfig -from vllm.config.compilation import Range from vllm.logger import logger @@ -309,7 +308,7 @@ class AddRMSNormQuantFusionPass(VllmInductorPass): logger.debug("Replaced %s patterns", self.matched_count) self.end_and_log() - def is_applicable_for_range(self, compile_range: Range) -> bool: + def is_applicable(self, runtime_shape: int | None = None) -> bool: """ Check if the pass is applicable for the current configuration. """ diff --git a/vllm_ascend/compilation/passes/qknorm_rope_fusion_pass.py b/vllm_ascend/compilation/passes/qknorm_rope_fusion_pass.py index ed90c7f8..f8355a15 100644 --- a/vllm_ascend/compilation/passes/qknorm_rope_fusion_pass.py +++ b/vllm_ascend/compilation/passes/qknorm_rope_fusion_pass.py @@ -22,7 +22,6 @@ from torch._inductor.pattern_matcher import (PatternMatcherPass, from vllm.attention.layer import Attention from vllm.compilation.vllm_inductor_pass import VllmInductorPass from vllm.config import VllmConfig, get_layers_from_vllm_config -from vllm.config.compilation import Range from vllm.logger import logger @@ -284,7 +283,7 @@ class QKNormRopeFusionPass(VllmInductorPass): pattern_idx += 1 self.end_and_log() - def is_applicable_for_range(self, compile_range: Range) -> bool: + def is_applicable(self, runtime_shape): """ Check if the pass is applicable for the current configuration. """ diff --git a/vllm_ascend/patch/__init__.py b/vllm_ascend/patch/__init__.py index a1037855..abdb631e 100644 --- a/vllm_ascend/patch/__init__.py +++ b/vllm_ascend/patch/__init__.py @@ -106,6 +106,20 @@ # Future Plan: # Remove this patch when vLLM merge the PR. # +# ** 7. 
File: platform/patch_compile_backend.py** +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# 1. `vllm.compilation.backends.PiecewiseCompileInterpreter` +# `vllm.compilation.piecewise_backend.PiecewiseBackend` +# Why: +# vllm removed the compiled graph for general shape, which caused operator fusion to fail. +# This issue affects the performance of model inference on Ascend. +# How: +# recover the compiled graph for dynamic_shape in PiecewiseBackend. +# Related PR (if no, explain why): +# https://github.com/vllm-project/vllm/pull/24252 +# Future Plan: +# Remove this patch when the problem is fixed. +# +# * Worker Patch: +# =============== +# diff --git a/vllm_ascend/patch/platform/__init__.py b/vllm_ascend/patch/platform/__init__.py index 49840db3..cc33cde1 100644 --- a/vllm_ascend/patch/platform/__init__.py +++ b/vllm_ascend/patch/platform/__init__.py @@ -16,6 +16,7 @@ import os +import vllm_ascend.patch.platform.patch_compile_backend # noqa import vllm_ascend.patch.platform.patch_distributed # noqa import vllm_ascend.patch.platform.patch_ec_connector # noqa import vllm_ascend.patch.platform.patch_mamba_config # noqa diff --git a/vllm_ascend/patch/platform/patch_compile_backend.py b/vllm_ascend/patch/platform/patch_compile_backend.py new file mode 100644 index 00000000..af8ec53a --- /dev/null +++ b/vllm_ascend/patch/platform/patch_compile_backend.py @@ -0,0 +1,235 @@ +from collections.abc import Callable +from typing import Any + +import torch +import torch.fx as fx +import vllm.compilation.backends +import vllm.compilation.piecewise_backend +from torch._dispatch.python import enable_python_dispatcher +from vllm.compilation.backends import VllmBackend +from vllm.compilation.counter import compilation_counter +from vllm.compilation.piecewise_backend import RangeEntry +from vllm.config import CUDAGraphMode, VllmConfig +from vllm.config.utils import Range +from vllm.logger import init_logger +from vllm.platforms import current_platform +from vllm.utils.import_utils
import resolve_obj_by_qualname + +logger = init_logger(__name__) + + +class AscendPiecewiseCompileInterpreter(torch.fx.Interpreter): + """Code adapted from `torch.fx.passes.shape_prop.ShapeProp`. + It runs the given graph with fake inputs, and compile some + submodules specified by `compile_submod_names` with the given + compilation configs. + + NOTE: the order in `compile_submod_names` matters, because + it will be used to determine the order of the compiled piecewise + graphs. The first graph will handle logging, and the last graph + has some special cudagraph output handling. + """ + + def __init__( + self, + module: torch.fx.GraphModule, + compile_submod_names: list[str], + vllm_config: VllmConfig, + vllm_backend: "VllmBackend", + ): + super().__init__(module) + from torch._guards import detect_fake_mode + + self.fake_mode = detect_fake_mode() + self.compile_submod_names = compile_submod_names + self.compilation_config = vllm_config.compilation_config + self.vllm_config = vllm_config + self.vllm_backend = vllm_backend + # When True, it annoyingly dumps the torch.fx.Graph on errors. + self.extra_traceback = False + + def run(self, *args): + # maybe instead just assert inputs are fake? 
+ fake_args = [ + self.fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t + for t in args + ] + with self.fake_mode, enable_python_dispatcher(): + return super().run(*fake_args) + + def call_module( + self, + target: torch.fx.node.Target, + args: tuple[torch.fx.node.Argument, ...], + kwargs: dict[str, Any], + ) -> Any: + assert isinstance(target, str) + + output = super().call_module(target, args, kwargs) + + if target in self.compile_submod_names: + index = self.compile_submod_names.index(target) + submod = self.fetch_attr(target) + + sym_shape_indices = [ + i for i, x in enumerate(args) if isinstance(x, torch.SymInt) + ] + max_num_batched_tokens = self.vllm_config.scheduler_config.max_num_batched_tokens + r1 = Range(start=1, end=max_num_batched_tokens) + compiled_graph_for_dynamic_shape = ( + self.vllm_backend.compiler_manager.compile( + submod, + args, + self.vllm_backend.inductor_config, + self.compilation_config, + graph_index=index, + num_graphs=len(self.compile_submod_names), + compile_range=r1, + )) + + # Lazy import here to avoid circular import + from vllm.compilation.piecewise_backend import PiecewiseBackend + + piecewise_backend = PiecewiseBackend( + submod, + self.vllm_config, + index, + len(self.compile_submod_names), + sym_shape_indices, + compiled_graph_for_dynamic_shape, + self.vllm_backend, + ) + + if (self.compilation_config.cudagraph_mode. + has_piecewise_cudagraphs() and + not self.compilation_config.use_inductor_graph_partition): + # We're using Dynamo-based piecewise splitting, so we wrap + # the whole subgraph with a static graph wrapper. + from vllm.compilation.cuda_graph import CUDAGraphOptions + + # resolve the static graph wrapper class (e.g. CUDAGraphWrapper + # class) as platform dependent. 
+ static_graph_wrapper_class = resolve_obj_by_qualname( + current_platform.get_static_graph_wrapper_cls()) + + # Always assign PIECEWISE runtime mode to the + # CUDAGraphWrapper for piecewise_backend, to distinguish + # it from the FULL cudagraph runtime mode, no matter it + # is wrapped on a full or piecewise fx graph. + self.module.__dict__[target] = static_graph_wrapper_class( + runnable=piecewise_backend, + vllm_config=self.vllm_config, + runtime_mode=CUDAGraphMode.PIECEWISE, + cudagraph_options=CUDAGraphOptions( + debug_log_enable=piecewise_backend.is_first_graph, + gc_disable=not piecewise_backend.is_first_graph, + weak_ref_output=piecewise_backend.is_last_graph, + ), + ) + else: + self.module.__dict__[target] = piecewise_backend + + compilation_counter.num_piecewise_capturable_graphs_seen += 1 + + return output + + +class AscendPiecewiseBackend: + + def __init__( + self, + graph: fx.GraphModule, + vllm_config: VllmConfig, + piecewise_compile_index: int, + total_piecewise_compiles: int, + sym_shape_indices: list[int], + compiled_graph_for_general_shape: Callable, + vllm_backend: VllmBackend, + ): + """ + The backend for piecewise compilation. + It mainly handles the compilation of static shapes and + dispatching based on runtime shape. + + We will compile `self.graph` once for the general shape, + and then compile for different shapes specified in + `compilation_config.compile_sizes`. 
+ """ + self.compiled_graph_for_general_shape = compiled_graph_for_general_shape + self.graph = graph + self.vllm_config = vllm_config + self.compilation_config = vllm_config.compilation_config + self.piecewise_compile_index = piecewise_compile_index + self.total_piecewise_compiles = total_piecewise_compiles + self.vllm_backend = vllm_backend + + self.is_first_graph = piecewise_compile_index == 0 + self.is_last_graph = piecewise_compile_index == total_piecewise_compiles - 1 + + self.is_full_graph = total_piecewise_compiles == 1 + self.is_encoder_compilation = vllm_backend.is_encoder + + self.compile_ranges = self.compilation_config.get_compile_ranges() + if self.is_encoder_compilation: + # For encoder compilation we use the max int32 value + # to set the upper bound of the compile ranges + max_int32 = 2**31 - 1 + last_compile_range = self.compile_ranges[-1] + assert (last_compile_range.end == + vllm_config.scheduler_config.max_num_batched_tokens) + self.compile_ranges[-1] = Range(start=last_compile_range.start, + end=max_int32) + + log_string = f"PiecewiseBackend: compile_ranges: {self.compile_ranges}" + logger.debug_once(log_string) + + self.compile_sizes = self.compilation_config.compile_sizes + log_string = f"PiecewiseBackend: compile_sizes: {self.compile_sizes}" + logger.debug_once(log_string) + + self.sym_shape_indices = sym_shape_indices + + # the entries for ranges that we need to either + self.range_entries: dict[Range, RangeEntry] = {} + + # to_be_compiled_ranges tracks the remaining ranges to compile, + # and updates during the compilation process, so we need to copy it + self.to_be_compiled_ranges: set[Range] = set(self.compile_ranges) + + # We only keep compilation management inside this class directly. 
+ for size in self.compile_sizes: + range = Range(start=size, end=size) + if range not in self.compile_ranges: + self.range_entries[range] = RangeEntry(compile_range=range, ) + self.to_be_compiled_ranges.add(range) + + for range in self.compile_ranges: + self.range_entries[range] = RangeEntry(compile_range=range, ) + + def _find_range_for_shape(self, runtime_shape: int) -> Range | None: + # First we try to find the range entry for the concrete compile size + # If not found, we search for the range entry + # that contains the runtime shape. + if runtime_shape in self.compile_sizes: + return self.range_entries[Range(start=runtime_shape, + end=runtime_shape)] + else: + for range in self.compile_ranges: + if runtime_shape in range: + return self.range_entries[range] + return None + + def __call__(self, *args) -> Any: + runtime_shape = args[self.sym_shape_indices[0]] + range_entry = self._find_range_for_shape(runtime_shape) + + assert range_entry is not None, ( + f"Shape out of considered range: {runtime_shape} " + "[1, max_num_batched_tokens]") + + return self.compiled_graph_for_general_shape(*args) + + +vllm.compilation.backends.PiecewiseCompileInterpreter = AscendPiecewiseCompileInterpreter +vllm.compilation.piecewise_backend.PiecewiseBackend.__init__ = AscendPiecewiseBackend.__init__ +vllm.compilation.piecewise_backend.PiecewiseBackend.__call__ = AscendPiecewiseBackend.__call__ diff --git a/vllm_ascend/worker/worker.py b/vllm_ascend/worker/worker.py index ea69c3f8..6e437208 100644 --- a/vllm_ascend/worker/worker.py +++ b/vllm_ascend/worker/worker.py @@ -27,7 +27,7 @@ import torch_npu import vllm.envs as envs_vllm from torch_npu.op_plugin.atb._atb_ops import _register_atb_extensions from torch_npu.profiler import dynamic_profile as dp -from vllm.config import CUDAGraphMode, VllmConfig, set_current_vllm_config +from vllm.config import VllmConfig, set_current_vllm_config from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment) from 
vllm.distributed.ec_transfer import ensure_ec_transfer_initialized @@ -372,25 +372,11 @@ class NPUWorker(WorkerBase): self.model_runner.eplb_warmup() warmup_sizes = (self.vllm_config.compilation_config.compile_sizes or []).copy() - cg_capture_sizes: list[int] = [] - if self.vllm_config.compilation_config.cudagraph_mode != CUDAGraphMode.NONE: - cg_sizes = self.vllm_config.compilation_config.cudagraph_capture_sizes - cg_capture_sizes = [] if cg_sizes is None else cg_sizes + if not self.model_config.enforce_eager: warmup_sizes = [ - x for x in warmup_sizes if x not in cg_capture_sizes + x for x in warmup_sizes if x not in + self.vllm_config.compilation_config.cudagraph_capture_sizes ] - - compile_ranges = self.vllm_config.compilation_config.get_compile_ranges( - ) - # For each compile_range, if none of the batch sizes - # in warmup_sizes or cudagraph_capture_sizes are in the range, - # add the end of the range to ensure compilation/warmup. - all_sizes = set(cg_capture_sizes) - all_sizes.update([x for x in warmup_sizes if isinstance(x, int)]) - for compile_range in compile_ranges: - if not any(x in compile_range for x in all_sizes): - warmup_sizes.append(compile_range.end) - for size in sorted(warmup_sizes, reverse=True): logger.info("Compile and warming up model for size %d", size) self.model_runner._dummy_run(size)