diff --git a/tests/ut/test_ascend_config.py b/tests/ut/test_ascend_config.py index ad252b4..4c7cfa6 100644 --- a/tests/ut/test_ascend_config.py +++ b/tests/ut/test_ascend_config.py @@ -262,6 +262,40 @@ class TestAscendConfig(TestBase): } init_ascend_config(test_vllm_config) + # use_cached_kv_cache_bytes should not be enabled without torchair graph mode + with self.assertRaises(RuntimeError): + test_vllm_config.additional_config = { + "torchair_graph_config": { + "enabled": False, + "use_cached_kv_cache_bytes": True, + }, + "refresh": True + } + init_ascend_config(test_vllm_config) + + # graph_batch_sizes should not be set without torchair graph mode + with self.assertRaises(RuntimeError): + test_vllm_config.additional_config = { + "torchair_graph_config": { + "enabled": False, + "graph_batch_sizes": [1, 2, 4], + }, + "refresh": True + } + init_ascend_config(test_vllm_config) + + # use_cached_kv_cache_bytes is valid only when torchair graph mode and use_cached_graph are enabled + with self.assertRaises(RuntimeError): + test_vllm_config.additional_config = { + "torchair_graph_config": { + "enabled": True, + "use_cached_graph": False, + "use_cached_kv_cache_bytes": True, + }, + "refresh": True + } + init_ascend_config(test_vllm_config) + # graph_batch_sizes_init should not be enabled without torchair graph mode with self.assertRaises(RuntimeError): test_vllm_config.additional_config = { diff --git a/tests/ut/torchair/test_utils.py b/tests/ut/torchair/test_utils.py index 8aebb9d..b68bc31 100644 --- a/tests/ut/torchair/test_utils.py +++ b/tests/ut/torchair/test_utils.py @@ -49,6 +49,16 @@ class TestTorchairUtils(TestBase): self.assertFalse(utils.check_kv_cache_bytes_cache_exist(), "Delete kv cache bytes cache dir failed") + def test_delete_torchair_cache_file_multiple_times(self): + utils.write_kv_cache_bytes_to_file(0, 100) + utils.delete_torchair_cache_file() + for i in range(5): + try: + utils.delete_torchair_cache_file() + except FileNotFoundError: + self.fail( + 
f"Unexpected FileNotFoundError on delete call #{i+2}") + @patch('vllm.ModelRegistry') def test_register_torchair_model(self, mock_model_registry): mock_registry = MagicMock() diff --git a/vllm_ascend/ascend_config.py b/vllm_ascend/ascend_config.py index 597ff5b..e46cd9a 100644 --- a/vllm_ascend/ascend_config.py +++ b/vllm_ascend/ascend_config.py @@ -73,6 +73,8 @@ class TorchairGraphConfig: self.mode = torchair_graph_config.get("mode", '') self.use_cached_graph = torchair_graph_config.get( "use_cached_graph", False) + self.use_cached_kv_cache_bytes = torchair_graph_config.get( + "use_cached_kv_cache_bytes", False) self.graph_batch_sizes = torchair_graph_config.get( "graph_batch_sizes", []) self.graph_batch_sizes_init = torchair_graph_config.get( @@ -99,6 +101,10 @@ class TorchairGraphConfig: raise RuntimeError( "use_cached_graph is valid only when Torchair graph mode is enabled" ) + if self.use_cached_kv_cache_bytes: + raise RuntimeError( + "use_cached_kv_cache_bytes is valid only when Torchair graph mode is enabled" + ) if self.graph_batch_sizes: raise RuntimeError( "graph_batch_sizes is valid only when Torchair graph mode is enabled" @@ -119,6 +125,10 @@ class TorchairGraphConfig: raise RuntimeError( "enable_kv_nz is valid only when Torchair graph mode is enabled" ) + if self.use_cached_kv_cache_bytes and not self.use_cached_graph: + raise RuntimeError( + "use_cached_kv_cache_bytes is valid only when Torchair graph mode and use_cached_graph are enabled" + ) class AscendSchedulerConfig: diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py index afc0e6b..57ace2b 100644 --- a/vllm_ascend/platform.py +++ b/vllm_ascend/platform.py @@ -28,6 +28,8 @@ from vllm.platforms import Platform, PlatformEnum from vllm_ascend.ascend_config import (check_ascend_config, get_ascend_config, init_ascend_config) +from vllm_ascend.torchair.utils import (check_torchair_cache_exist, + delete_torchair_cache_file) from vllm_ascend.utils import (ASCEND_QUANTIZATION_METHOD, is_310p, 
update_aclgraph_sizes) @@ -170,6 +172,18 @@ class NPUPlatform(Platform): "Torchair compilation enabled on NPU. Setting CUDAGraphMode to NONE" ) compilation_config.cudagraph_mode = CUDAGraphMode.NONE + # Note: We delete the torchair cache folder here to prevent runtime issues caused by dimension + # mismatches or configuration inconsistencies when users reuse cached computation graphs. Though + # this will increase graph compilation duration, it significantly enhances robustness and decreases + # graph launching time during inference. + if check_torchair_cache_exist( + ) and not ascend_config.torchair_graph_config.use_cached_kv_cache_bytes: + logger.warning( + "Torchair cache folder is deleted here to prevent runtime issues caused by dimension " + "mismatches or configuration inconsistencies when users reuse cached computation graphs. " + "In order to decrease torchair graph compilation time, users can enable both use_cached_graph " + "and use_cached_kv_cache_bytes in torchair_graph_config.") + delete_torchair_cache_file() if parallel_config.distributed_executor_backend == "ray": logger.warning( diff --git a/vllm_ascend/torchair/torchair_mla.py b/vllm_ascend/torchair/torchair_mla.py index 036db47..30ef293 100644 --- a/vllm_ascend/torchair/torchair_mla.py +++ b/vllm_ascend/torchair/torchair_mla.py @@ -628,6 +628,7 @@ class AscendMLATorchairImpl(MLAAttentionImpl): self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp + self.running_in_graph = False # Adapt torch air graph mode with spec decoding. 
speculative_config = get_current_vllm_config().speculative_config diff --git a/vllm_ascend/torchair/torchair_model_runner.py b/vllm_ascend/torchair/torchair_model_runner.py index 7d2f605..2b34f9b 100644 --- a/vllm_ascend/torchair/torchair_model_runner.py +++ b/vllm_ascend/torchair/torchair_model_runner.py @@ -35,8 +35,9 @@ import vllm_ascend.envs as envs_ascend from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.platform import NPUPlatform from vllm_ascend.torchair.utils import ( - TorchairCommonAttentionMetadata, check_torchair_cache_exist, - converting_weight_acl_format, register_torchair_model, torchair_ops_patch, + TORCHAIR_CACHE_DIR, TorchairCommonAttentionMetadata, + check_torchair_cache_exist, converting_weight_acl_format, + register_torchair_model, torchair_ops_patch, torchair_quant_method_register, write_kv_cache_bytes_to_file) from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ, is_310p) @@ -52,6 +53,7 @@ class NPUTorchairModelRunner(NPUModelRunner): self.torchair_compiled_model = None # type: ignore self.torchair_compiled_models = {} # type: ignore self.use_cached_npu_graph = ascend_config.torchair_graph_config.use_cached_graph + self.use_cached_kv_cache_bytes = ascend_config.torchair_graph_config.use_cached_kv_cache_bytes self.torchair_graph_batch_sizes = ascend_config.torchair_graph_config.graph_batch_sizes if ascend_config.torchair_graph_config.graph_batch_sizes_init: self.init_torchair_graph_batch_sizes() @@ -194,14 +196,20 @@ class NPUTorchairModelRunner(NPUModelRunner): graph_num = len(torchair_graph_batch_sizes) if self.use_cached_npu_graph and not check_torchair_cache_exist(): - # If caching is enabled but does not exist, we will compile the model twice. The first - # time is used to generate the cache, and the second time is used to load the cache to - # skip the overhead caused by Dynamo guard mechanism. 
+ # If caching is enabled but does not exist (either + # use_cached_kv_cache_bytes is disabled or kv_cache_bytes are + # different), we will compile the model twice. The first time is + # used to generate the cache, and the second time is used to load the + # cache to skip the overhead caused by Dynamo guard mechanism. logger.info( - "Use cached npu graph but cache doesn't exist! Now we compile graph to genetate torchair cache, this usually takes %.1f~%.1f mins.", + "Cache compilation for torchair graph is enabled. Now we compile graph to generate" + " torchair cache, this usually takes %.1f~%.1f mins.", 0.5 * graph_num, 1.5 * graph_num) self._compile_torchair_graph(torchair_graph_batch_sizes) NPUPlatform.synchronize() + # Note: We reset dynamo and reload the compiled torchair cached computation graph below + # that was compiled above. This operation reduces graph launch time by 2-4ms and avoids + # runtime errors caused by configuration mismatches in graph mode. torch._dynamo.reset() self.torchair_compiled_models.clear() if self.use_cached_npu_graph: @@ -215,7 +223,7 @@ class NPUTorchairModelRunner(NPUModelRunner): 0.5 * graph_num, 1.5 * graph_num) self._compile_torchair_graph(torchair_graph_batch_sizes) - if self.new_kv_cache_bytes > 0: + if self.use_cached_kv_cache_bytes and self.new_kv_cache_bytes > 0: write_kv_cache_bytes_to_file(torch.distributed.get_rank(), self.new_kv_cache_bytes) @@ -362,6 +370,7 @@ class NPUTorchairModelRunner(NPUModelRunner): self.model.__dict__[forward_proxy_name], dynamic=True, fullgraph=envs_vllm.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE, + cache_dir=TORCHAIR_CACHE_DIR, config=config, ge_cache=False) return self.torchair_compiled_models[batch_size] diff --git a/vllm_ascend/torchair/torchair_worker.py b/vllm_ascend/torchair/torchair_worker.py index 3488ac7..85f2fb4 100644 --- a/vllm_ascend/torchair/torchair_worker.py +++ b/vllm_ascend/torchair/torchair_worker.py @@ -17,9 +17,9 @@ import torch from vllm.logger import logger import 
vllm_ascend.envs as envs_ascend +from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.torchair.torchair_model_runner import NPUTorchairModelRunner from vllm_ascend.torchair.utils import (check_kv_cache_bytes_cache_exist, - check_torchair_cache_exist, delete_torchair_cache_file, read_kv_cache_bytes_from_file) from vllm_ascend.worker.worker_v1 import NPUWorker @@ -33,7 +33,9 @@ class NPUTorchairWorker(NPUWorker): available_kv_cache_memory = super().determine_available_memory() - if check_torchair_cache_exist() and check_kv_cache_bytes_cache_exist(): + if get_ascend_config( + ).torchair_graph_config.use_cached_kv_cache_bytes and check_kv_cache_bytes_cache_exist( + ): old_kv_cache_bytes = read_kv_cache_bytes_from_file( torch.distributed.get_rank()) if 0 < old_kv_cache_bytes <= available_kv_cache_memory: diff --git a/vllm_ascend/torchair/utils.py b/vllm_ascend/torchair/utils.py index 8dd1e3f..563fda7 100644 --- a/vllm_ascend/torchair/utils.py +++ b/vllm_ascend/torchair/utils.py @@ -18,8 +18,8 @@ except ImportError: KV_CACHE_BYTES_CACHE_PATH_NAME = ".kv_cache_bytes" KV_CACHE_BYTES_CACHE_FILE_NAME = "kv_cache_bytes" TORCHAIR_CACHE_PATH_NAME = ".torchair_cache" -TORCHAIR_CACHE_DIR = os.getenv( - 'TORCHAIR_CACHE_HOME', os.path.join(os.getcwd(), TORCHAIR_CACHE_PATH_NAME)) +TORCHAIR_CACHE_DIR = os.path.join( + os.getenv('TORCHAIR_CACHE_HOME', os.getcwd()), TORCHAIR_CACHE_PATH_NAME) @dataclass @@ -111,8 +111,10 @@ def write_kv_cache_bytes_to_file(rank, kv_cache_bytes): def delete_torchair_cache_file(): torch_air_abs_path = _get_torchair_current_work_dir() - if os.path.exists(torch_air_abs_path): + try: shutil.rmtree(torch_air_abs_path) + except FileNotFoundError: + pass def npu_stream_switch(tag: str, priority: int, *, enabled: bool = True): diff --git a/vllm_ascend/worker/mtp_proposer_v1.py b/vllm_ascend/worker/mtp_proposer_v1.py index e5d555b..e8f369f 100644 --- a/vllm_ascend/worker/mtp_proposer_v1.py +++ b/vllm_ascend/worker/mtp_proposer_v1.py @@ -20,7 
+20,8 @@ from vllm_ascend.attention.utils import AscendCommonAttentionMetadata from vllm_ascend.models.deepseek_mtp import CustomDeepSeekMTP from vllm_ascend.torchair.models.torchair_deepseek_mtp import \ TorchairDeepSeekMTP -from vllm_ascend.torchair.utils import TorchairCommonAttentionMetadata +from vllm_ascend.torchair.utils import (TORCHAIR_CACHE_DIR, + TorchairCommonAttentionMetadata) from vllm_ascend.utils import ProfileExecuteDuration, lmhead_tp_enable @@ -411,6 +412,7 @@ class MtpProposer: self.model.__dict__[forward_proxy_name], dynamic=True, fullgraph=envs_vllm.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE, + cache_dir=TORCHAIR_CACHE_DIR, config=config, ge_cache=False) return self.torchair_compiled_models[batch_size]