diff --git a/vllm_ascend/torchair/torchair_model_runner.py b/vllm_ascend/torchair/torchair_model_runner.py
index b3b8ecbe..f42f83d1 100644
--- a/vllm_ascend/torchair/torchair_model_runner.py
+++ b/vllm_ascend/torchair/torchair_model_runner.py
@@ -23,7 +23,11 @@
 import torch
 import torch_npu
 from vllm.config import VllmConfig
 from vllm.forward_context import get_forward_context
+from vllm.logger import logger
+from vllm_ascend.platform import NPUPlatform
+from vllm_ascend.torchair.utils import (check_torchair_cache_exist,
+                                        write_kv_cache_bytes_to_file)
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
                                maybe_converting_weight_acl_format)
 from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
@@ -37,6 +41,7 @@ class NPUTorchairModelRunner(NPUModelRunner):
     def _get_forward_metadata_across_dp_and_pad(
             self, num_tokens: int, with_prefill: bool, enable_dbo: bool
     ) -> tuple[int, Optional[torch.Tensor], bool, bool]:
+        """Override from NPUModelRunner to pad num_tokens to a torchair batch size."""
         if self.dp_size == 1:
             if not with_prefill:
                 maybe_padded_num_tokens = self.select_torchair_padded_batch_size(
@@ -118,3 +123,49 @@ class NPUTorchairModelRunner(NPUModelRunner):
     def _convert_torch_format(self, kv_cache):
         kv_cache = torch_npu.npu_format_cast(kv_cache, ACL_FORMAT_FRACTAL_ND)
         return kv_cache
+
+    def _compile_torchair_graph(self, torchair_graph_batch_sizes) -> None:
+        # Trigger torchair graph capture for specific shapes.
+        # Capture the large shapes first so that the smaller shapes
+        # can reuse the memory pool allocated for the large shapes.
+        for idx, num_tokens in enumerate(reversed(torchair_graph_batch_sizes)):
+            for _ in range(self.vllm_config.compilation_config.
+                           cudagraph_num_of_warmups):
+                self._dummy_run(num_tokens, is_torchair_compile=True)
+            self._dummy_run(num_tokens, is_torchair_compile=True)
+            logger.info("Batch size %d compiled successfully: %d/%d.",
+                        num_tokens, idx + 1, len(torchair_graph_batch_sizes))
+
+    def _capture_model(self):
+        """Override from NPUModelRunner to use torchair graph capture."""
+        # TODO(NeverRaR): Calling graph_capture(device=self.device) during
+        # torchair graph capture can cause some issues, so for now we
+        # temporarily split the code path for the two graph patterns.
+        torchair_graph_batch_sizes = self.torchair_graph_batch_sizes
+        graph_num = len(torchair_graph_batch_sizes)
+
+        if self.use_cached_npu_graph and not check_torchair_cache_exist():
+            # If caching is enabled but the cache doesn't exist yet, compile the model
+            # twice: the first pass generates the cache, and the second loads it to
+            # skip the overhead of Dynamo's guard mechanism.
+            logger.info(
+                "Use cached npu graph but cache doesn't exist! Now we compile the graph to generate the torchair cache; this usually takes %.1f~%.1f mins.",
+                0.5 * graph_num, 1.5 * graph_num)
+            self._compile_torchair_graph(torchair_graph_batch_sizes)
+            NPUPlatform.synchronize()
+            torch._dynamo.reset()
+            self.torchair_compiled_models.clear()
+        if self.use_cached_npu_graph:
+            logger.info(
+                "Loading torchair graph cache, this usually takes %.1f~%.1f mins.",
+                0.3 * graph_num, 0.5 * graph_num)
+            self._compile_torchair_graph(torchair_graph_batch_sizes)
+        else:
+            logger.info(
+                "Capturing torchair graph, this usually takes %.1f~%.1f mins.",
+                0.5 * graph_num, 1.5 * graph_num)
+            self._compile_torchair_graph(torchair_graph_batch_sizes)
+
+        if self.new_kv_cache_bytes > 0:
+            write_kv_cache_bytes_to_file(torch.distributed.get_rank(),
+                                         self.new_kv_cache_bytes)
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index ae1cff3f..594649c6 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -82,8 +82,6 @@ from vllm_ascend.attention.mla_v1 import AscendMLAMetadata
 from vllm_ascend.multistream.ms_split import compute_split_seq_index
 from vllm_ascend.platform import NPUPlatform
 from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler
-from vllm_ascend.torchair.utils import (check_torchair_cache_exist,
-                                        write_kv_cache_bytes_to_file)
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
                                ProfileExecuteDuration, is_310p,
                                maybe_converting_weight_acl_format,
@@ -2323,67 +2321,27 @@ class NPUModelRunner(LoRAModelRunnerMixin):
 
         return kv_cache_spec
 
-    def _compile_torchair_graph(self, torchair_graph_batch_sizes) -> None:
-        # Trigger torchair graph capture for specific shapes.
+    def _capture_model(self):
+        if not self.use_aclgraph:
+            logger.info("Skipping NPU graph capture for eager mode.")
+            return
+        # Trigger ACL graph capture for specific shapes.
         # Capture the large shapes first so that the smaller shapes
         # can reuse the memory pool allocated for the large shapes.
-        for idx, num_tokens in enumerate(reversed(torchair_graph_batch_sizes)):
-            for _ in range(self.vllm_config.compilation_config.
-                           cudagraph_num_of_warmups):
-                self._dummy_run(num_tokens, is_torchair_compile=True)
-            self._dummy_run(num_tokens, is_torchair_compile=True)
-            logger.info("Batchsize %d is compiled successfully: %d/%d.",
-                        num_tokens, idx + 1, len(torchair_graph_batch_sizes))
+        # TODO(zzzzwwjj): Check dummy_run with ACL Graph and full graph mode
+        with graph_capture(device=self.device):
+            for num_tokens in reversed(self.aclgraph_batch_sizes):
+                for _ in range(self.vllm_config.compilation_config.
+                               cudagraph_num_of_warmups):
+                    self._dummy_run(num_tokens)
+                self._dummy_run(num_tokens)
 
     def capture_model(self) -> None:
         start_time = time.perf_counter()
         start_free_npu_memory = torch.npu.mem_get_info()[0]
-        # TODO(NeverRaR): Calling graph_capture(device=self.device) in
-        # torchair graph capture can cause some issues, so now we just
-        # temporarily split the codepath for the two different graph patterns.
-        if self.torchair_graph_enabled:
-            torchair_graph_batch_sizes = self.torchair_graph_batch_sizes
-            graph_num = len(torchair_graph_batch_sizes)
-            if self.use_cached_npu_graph and not check_torchair_cache_exist():
-                # If caching is enabled but does not exist, we will compile the model twice. The first
-                # time is used to generate the cache, and the second time is used to load the cache to
-                # skip the overhead caused by Dynamo guard mechanism.
-                logger.info(
-                    "Use cached npu graph but cache doesn't exist! Now we compile graph to genetate torchair cache, this usually takes %.1f~%.1f mins.",
-                    0.5 * graph_num, 1.5 * graph_num)
-                self._compile_torchair_graph(torchair_graph_batch_sizes)
-                NPUPlatform.synchronize()
-                torch._dynamo.reset()
-                self.torchair_compiled_models.clear()
-            if self.use_cached_npu_graph:
-                logger.info(
-                    "Loading torchair graph cache, this usually takes %.1f~%.1f mins.",
-                    0.3 * graph_num, 0.5 * graph_num)
-                self._compile_torchair_graph(torchair_graph_batch_sizes)
-            else:
-                logger.info(
-                    "Capturing torchair graph, this usually takes %.1f~%.1f mins.",
-                    0.5 * graph_num, 1.5 * graph_num)
-                self._compile_torchair_graph(torchair_graph_batch_sizes)
+        self._capture_model()
 
-            if self.new_kv_cache_bytes > 0:
-                write_kv_cache_bytes_to_file(torch.distributed.get_rank(),
-                                             self.new_kv_cache_bytes)
-        elif self.use_aclgraph:
-            # Trigger ACL graph capture for specific shapes.
-            # Capture the large shapes first so that the smaller shapes
-            # can reuse the memory pool allocated for the large shapes.
-            # TODO(zzzzwwjj): Check dummy_run with ACL Graph and full graph mode
-            with graph_capture(device=self.device):
-                for num_tokens in reversed(self.aclgraph_batch_sizes):
-                    for _ in range(self.vllm_config.compilation_config.
-                                   cudagraph_num_of_warmups):
-                        self._dummy_run(num_tokens)
-                    self._dummy_run(num_tokens)
-        else:
-            logger.info("Skipping NPU graph capture for eager mode.")
-            return
         end_time = time.perf_counter()
         end_free_npu_memory = torch.npu.mem_get_info()[0]
         elapsed_time = end_time - start_time
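
Reviewer note: the change above is a template-method split. The base class
NPUModelRunner keeps the shared bookkeeping in capture_model() and delegates
the actual graph capture to the new _capture_model() hook, which
NPUTorchairModelRunner overrides with the torchair compile/cache path. A
minimal sketch of the resulting dispatch, with bodies elided (class and
method names are taken from the diff; everything else is illustrative):

    class NPUModelRunner:
        def _capture_model(self):
            # Default path: ACL graph capture, or an early return in
            # eager mode.
            ...

        def capture_model(self) -> None:
            # Shared timing and free-memory accounting stay here;
            # subclasses customize behavior only via _capture_model().
            self._capture_model()

    class NPUTorchairModelRunner(NPUModelRunner):
        def _capture_model(self):
            # Torchair path: compile per-batch-size graphs, optionally
            # generating and then loading the on-disk torchair cache
            # (compile twice with torch._dynamo.reset() in between).
            ...

One behavioral consequence worth noting: the eager-mode check now lives
inside the base _capture_model(), so capture_model() always runs its timing
code, whereas the old code returned before the timing block in eager mode.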