# Method added to NPUModelRunner (vllm_ascend/worker/model_runner_v1.py),
# recovered from a unified diff whose line breaks were lost. It relies on
# ``self.torchair_compiled_model`` (initially None) and
# ``self.torchair_compiled_models`` (initially {}) being set in ``__init__``,
# and on ``torch``, ``nn`` and ``envs_vllm`` being in scope in the module.
def _get_torchair_lazy_compiled_model(self, batch_size: int):
    """Return the torchair-compiled callable for ``batch_size``, compiling lazily.

    When ``self.use_cached_npu_graph`` is false, ``self.model`` is wrapped
    once with ``torch.compile`` (npu backend) and that single object is
    reused for every batch size. Otherwise one
    ``torchair.inference.cache_compile`` result is kept per batch size in
    ``self.torchair_compiled_models``, each compiled from a renamed copy of
    ``self.model.forward``.

    Args:
        batch_size: Graph batch size to fetch a compiled model for. Must lie
            in ``[0, self.max_num_reqs]``.

    Returns:
        The compiled callable for the requested batch size.

    Raises:
        ValueError: If ``batch_size`` is negative or exceeds
            ``self.max_num_reqs``.
    """
    if batch_size < 0 or batch_size > self.max_num_reqs:
        raise ValueError(
            f"Bad graph batch size:{batch_size}! max_num_reqs:{self.max_num_reqs}"
        )

    if self.use_cached_npu_graph:
        compiled_model = self.torchair_compiled_models.get(batch_size)
    else:
        compiled_model = self.torchair_compiled_model

    # Compare against None rather than relying on truthiness: a cached
    # object that happens to evaluate falsy (e.g. it defines __len__) must
    # still count as a cache hit instead of being recompiled on every call.
    if compiled_model is not None:
        return compiled_model

    # torchair is only needed on a cache miss, so it is imported lazily.
    import torchair  # type: ignore
    from torchair import patch_for_hcom  # type: ignore

    patch_for_hcom()
    config = torchair.CompilerConfig()
    config.experimental_config.frozen_parameter = True
    config.experimental_config.tiling_schedule_optimize = True
    torch.npu.set_compile_mode(jit_compile=False)

    if not self.use_cached_npu_graph:
        npu_backend = torchair.get_npu_backend(compiler_config=config)
        self.torchair_compiled_model = torch.compile(
            self.model,
            dynamic=True,
            fullgraph=envs_vllm.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
            backend=npu_backend)
        return self.torchair_compiled_model

    import types

    # Generate a new forward proxy code object to prevent the invalidation of
    # compilation cache caused by dynamo retracing
    forward_proxy_name = f"{self.model.__class__.__name__}_forward_with_batch_size_{batch_size}"
    forward_fn = self.model.forward
    # Mark code object with a new proxy name
    modified_code = forward_fn.__code__.replace(co_name=forward_proxy_name)

    # Carry over the closure cells as well as the positional defaults:
    # FunctionType refuses a code object with free variables (for example a
    # forward() that uses zero-argument super()) unless a matching closure
    # is supplied.
    modified_func = types.FunctionType(modified_code,
                                       forward_fn.__globals__,
                                       name=forward_proxy_name,
                                       argdefs=forward_fn.__defaults__,
                                       closure=forward_fn.__closure__)
    # Keyword-only defaults cannot be passed to the constructor on older
    # interpreters, so copy them onto the new function afterwards.
    kw_defaults = getattr(forward_fn, "__kwdefaults__", None)
    if kw_defaults:
        modified_func.__kwdefaults__ = dict(kw_defaults)

    # Bind the proxy to the model and store it in the instance __dict__
    # under its unique name.
    self.model.__dict__[forward_proxy_name] = modified_func.__get__(
        self.model, nn.Module)
    self.torchair_compiled_models[
        batch_size] = torchair.inference.cache_compile(
            self.model.__dict__[forward_proxy_name],
            dynamic=True,
            fullgraph=envs_vllm.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
            config=config,
            ge_cache=False)
    return self.torchair_compiled_models[batch_size]