From c7f1c59911027953147c9d9495457568fe2216c8 Mon Sep 17 00:00:00 2001
From: NeverRaR <44917563+NeverRaR@users.noreply.github.com>
Date: Fri, 6 Jun 2025 20:17:51 +0800
Subject: [PATCH] feat: support compile multiple batch graph (#1085)

### What this PR does / why we need it?
Support compiling multiple batch graphs with different code objects, so that the compilation cache is not invalidated when switching between graph batch sizes.

### How was this patch tested?
```
export VLLM_ENABLE_MC2=0
export VLLM_USE_V1=1
export TASK_QUEUE_ENABLE=1
source /usr/local/Ascend/ascend-toolkit/set_env.sh
source /usr/local/Ascend/nnal/atb/set_env.sh
nohup python -m vllm.entrypoints.openai.api_server --model=/mnt/deepseek/DeepSeek-R1-W8A8-VLLM \
    --quantization ascend \
    --served-model-name auto \
    --trust-remote-code \
    --distributed-executor-backend=mp \
    --port 8006 \
    -tp=8 \
    -dp=2 \
    --no-enforce-eager \
    --max-num-seqs 24 \
    --max-model-len 32768 \
    --max-num-batched-tokens 32768 \
    --block-size 128 \
    --no-enable-prefix-caching \
    --additional-config '{"torchair_graph_config": {"enabled": true,"use_cached_graph": true,"graph_batch_sizes": [8,16,24]},"ascend_scheduler_config": {"enabled":true,"chunked_prefill_enabled":false},"expert_tensor_parallel_size":16}' \
    --gpu-memory-utilization 0.95 &> run.log & disown
```

Signed-off-by: boying <897013703@qq.com>
---
 vllm_ascend/worker/model_runner_v1.py | 78 +++++++++++++++++++--------
 1 file changed, 57 insertions(+), 21 deletions(-)

diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index 269767f..12a596c 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -20,6 +20,7 @@
 import gc
 import os
 import time
+import types
 import weakref
 from contextlib import contextmanager, nullcontext
 from dataclasses import dataclass
@@ -321,6 +322,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
 
         self.sampler = Sampler()
 
+        self.torchair_compiled_model = None  # type: ignore
+        self.torchair_compiled_models = {}  # type: ignore
         ascend_config = get_ascend_config()
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled and self.vllm_config.model_config.use_mla
         self.use_cached_npu_graph = ascend_config.torchair_graph_config.use_cached_graph
@@ -713,7 +716,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             model_kwargs["kv_caches"] = self.kv_caches
             model_kwargs["attn_metadata"] = attn_metadata
         if self.torchair_graph_enabled and not with_prefill:
-            hidden_states = self.compile_model(
+            compiled_model = self._get_torchair_lazy_compiled_model(
+                padded_batch_size)
+            hidden_states = compiled_model(
                 input_ids=input_ids,
                 positions=positions,
                 intermediate_tensors=intermediate_tensors,
@@ -1190,7 +1195,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                     kv, tuple), "kv_cache must be a tuple"
                 torch._dynamo.mark_static(kv[0])
                 torch._dynamo.mark_static(kv[1])
-            hidden_states = self.compile_model(
+            compiled_model = self._get_torchair_lazy_compiled_model(
+                num_tokens)
+            hidden_states = compiled_model(
                 input_ids=input_ids,
                 positions=positions,
                 intermediate_tensors=intermediate_tensors,
@@ -1264,30 +1271,59 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             logger.info("Loading model weights took %.4f GB",
                         m.consumed_memory / float(2**30))
 
-        # adapter torch compile with npu_backend
-        if self.torchair_graph_enabled:
-            import torchair  # type: ignore
-            from torchair import patch_for_hcom  # type: ignore
+    def _get_torchair_lazy_compiled_model(self, batch_size: int):
+        if batch_size < 0 or batch_size > self.max_num_reqs:
+            raise ValueError(
+                f"Bad graph batch size:{batch_size}! max_num_reqs:{self.max_num_reqs}"
+            )
 
-            patch_for_hcom()
-            config = torchair.CompilerConfig()
-            config.experimental_config.frozen_parameter = True
-            config.experimental_config.tiling_schedule_optimize = True
-            torch.npu.set_compile_mode(jit_compile=False)
-            if not self.use_cached_npu_graph:
-                npu_backend = torchair.get_npu_backend(compiler_config=config)
-                self.compile_model = torch.compile(
-                    self.model,
-                    dynamic=True,
-                    fullgraph=envs_vllm.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
-                    backend=npu_backend)
-            else:
-                self.compile_model = torchair.inference.cache_compile(
-                    self.model.forward,
+        compiled_model = self.torchair_compiled_models.get(
+            batch_size
+        ) if self.use_cached_npu_graph else self.torchair_compiled_model
+
+        if compiled_model:
+            return compiled_model
+
+        import torchair  # type: ignore
+        from torchair import patch_for_hcom  # type: ignore
+
+        patch_for_hcom()
+        config = torchair.CompilerConfig()
+        config.experimental_config.frozen_parameter = True
+        config.experimental_config.tiling_schedule_optimize = True
+        torch.npu.set_compile_mode(jit_compile=False)
+        if not self.use_cached_npu_graph:
+            npu_backend = torchair.get_npu_backend(compiler_config=config)
+            self.torchair_compiled_model = torch.compile(
+                self.model,
+                dynamic=True,
+                fullgraph=envs_vllm.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
+                backend=npu_backend)
+            return self.torchair_compiled_model
+        else:
+            # Generate a new forward proxy code object to prevent the invalidation of
+            # compilation cache caused by dynamo retracing
+            forward_proxy_name = f"{self.model.__class__.__name__}_forward_with_batch_size_{batch_size}"
+            forward_fn = self.model.forward
+            code = forward_fn.__code__
+            # Mark code object with a new proxy name
+            modified_code = code.replace(co_name=forward_proxy_name, )
+
+            modified_func = types.FunctionType(modified_code,
+                                               forward_fn.__globals__,
+                                               name=forward_proxy_name,
+                                               argdefs=forward_fn.__defaults__)
+
+            self.model.__dict__[forward_proxy_name] = modified_func.__get__(
+                self.model, nn.Module)
+            self.torchair_compiled_models[
+                batch_size] = torchair.inference.cache_compile(
+                    self.model.__dict__[forward_proxy_name],
                     dynamic=True,
                     fullgraph=envs_vllm.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
                     config=config,
                     ge_cache=False)
+            return self.torchair_compiled_models[batch_size]
 
     def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
         """
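
The core of `_get_torchair_lazy_compiled_model` above is the proxy-code-object trick: the model's `forward` is cloned under a new `co_name` for each batch size, so each graph has a distinct function identity and, per the PR description, the compilation cache is not invalidated by dynamo retracing. Below is a minimal, standalone sketch of just that trick; `DummyModel` and `make_proxy` are illustrative names, not part of this patch, and no torchair/NPU dependency is assumed.

```python
import types


class DummyModel:
    """Stand-in for the real nn.Module; only used to illustrate the trick."""

    def forward(self, x):
        return x * 2


def make_proxy(model, batch_size: int):
    # Same naming scheme as the patch: one proxy per (model class, batch size).
    proxy_name = f"{model.__class__.__name__}_forward_with_batch_size_{batch_size}"
    forward_fn = model.forward.__func__  # plain function behind the bound method
    # Clone the code object under a new co_name so each batch size gets its
    # own code identity; this is what keeps the compiled-graph caches apart.
    modified_code = forward_fn.__code__.replace(co_name=proxy_name)
    proxy_fn = types.FunctionType(modified_code,
                                  forward_fn.__globals__,
                                  name=proxy_name,
                                  argdefs=forward_fn.__defaults__)
    # Bind the proxy to the instance, mirroring modified_func.__get__(self.model, nn.Module).
    bound = proxy_fn.__get__(model, model.__class__)
    model.__dict__[proxy_name] = bound
    return bound


model = DummyModel()
f8 = make_proxy(model, 8)
f16 = make_proxy(model, 16)
assert f8(3) == f16(3) == 6             # behaves exactly like model.forward
assert f8.__code__ is not f16.__code__  # but each has its own code object
print(f8.__code__.co_name)              # DummyModel_forward_with_batch_size_8
print(f16.__code__.co_name)             # DummyModel_forward_with_batch_size_16
```

In the patch itself, each such bound proxy is what gets passed to `torchair.inference.cache_compile`, so every entry in `graph_batch_sizes` ends up with its own cached graph rather than repeatedly invalidating a single cache.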