[bugfix] fix torchair runtime error caused by configuration mismtaches and file missing (#2532)
### What this PR does / why we need it?
This PR ports #2312 #2506 #2531 to main branch.
Original implementation of torchair caching forces users to make
everything prepared, fix all the configuration and enable
`use_cached_npu_graph`, and it might cause some problems confusing to
understand and tackle for users. It is better to compile the graph twice
instead of reusing the old kvcaches and cached torchair graph. And the
extra duration time is acceptable. Additionally, this pr fixes a
recompilation problem of torchair graph mode caused by
`running_in_graph` variable in `AscendMLATorchairImpl`.
### Does this PR introduce _any_ user-facing change?
If users want to enabling torchair.cache_compile with high compilation
speed, it is recommended to enable both `use_cached_kv_cache_bytes` and
`use_cached_graph` in `torchair_graph_config`. Without
`use_cached_kv_cache_bytes`, we'll compile torchair computation graph
twice to avoid runtime error caused by configuration mismtaches (the
second compilation will be much faster). Additionally, we've made a
change to how the TORCHAIR_CACHE_HOME enviroment variable is utilized to
enhance safety and prevent accidental file deletion by adding a suffix
directory.
### How was this patch tested?
CI and e2e vllm serving pass.
- vLLM version: v0.10.1.1
- vLLM main:
70549c1245
---------
Signed-off-by: linfeng-yuan <1102311262@qq.com>
This commit is contained in:
@@ -262,6 +262,40 @@ class TestAscendConfig(TestBase):
|
||||
}
|
||||
init_ascend_config(test_vllm_config)
|
||||
|
||||
# use_cached_kv_cache_bytes should not be enabled without torchair graph mode
|
||||
with self.assertRaises(RuntimeError):
|
||||
test_vllm_config.additional_config = {
|
||||
"torchair_graph_config": {
|
||||
"enabled": False,
|
||||
"use_cached_kv_cache_bytes": True,
|
||||
},
|
||||
"refresh": True
|
||||
}
|
||||
init_ascend_config(test_vllm_config)
|
||||
|
||||
# graph_batch_sizes should not be set without torchair graph mode
|
||||
with self.assertRaises(RuntimeError):
|
||||
test_vllm_config.additional_config = {
|
||||
"torchair_graph_config": {
|
||||
"enabled": False,
|
||||
"graph_batch_sizes": [1, 2, 4],
|
||||
},
|
||||
"refresh": True
|
||||
}
|
||||
init_ascend_config(test_vllm_config)
|
||||
|
||||
# use_cached_kv_cache_bytes is valid only when torchair graph mode and use_cached_graph are enabled
|
||||
with self.assertRaises(RuntimeError):
|
||||
test_vllm_config.additional_config = {
|
||||
"torchair_graph_config": {
|
||||
"enabled": True,
|
||||
"use_cached_graph": False,
|
||||
"use_cached_kv_cache_bytes": True,
|
||||
},
|
||||
"refresh": True
|
||||
}
|
||||
init_ascend_config(test_vllm_config)
|
||||
|
||||
# graph_batch_sizes_init should not be enabled without torchair graph mode
|
||||
with self.assertRaises(RuntimeError):
|
||||
test_vllm_config.additional_config = {
|
||||
|
||||
@@ -49,6 +49,16 @@ class TestTorchairUtils(TestBase):
|
||||
self.assertFalse(utils.check_kv_cache_bytes_cache_exist(),
|
||||
"Delete kv cache bytes cache dir failed")
|
||||
|
||||
def test_delete_torchair_cache_file_multiple_times(self):
|
||||
utils.write_kv_cache_bytes_to_file(0, 100)
|
||||
utils.delete_torchair_cache_file()
|
||||
for i in range(5):
|
||||
try:
|
||||
utils.delete_torchair_cache_file()
|
||||
except FileNotFoundError:
|
||||
self.fail(
|
||||
f"Unexpected FileNotFoundError on delete call #{i+2}")
|
||||
|
||||
@patch('vllm.ModelRegistry')
|
||||
def test_register_torchair_model(self, mock_model_registry):
|
||||
mock_registry = MagicMock()
|
||||
|
||||
@@ -73,6 +73,8 @@ class TorchairGraphConfig:
|
||||
self.mode = torchair_graph_config.get("mode", '')
|
||||
self.use_cached_graph = torchair_graph_config.get(
|
||||
"use_cached_graph", False)
|
||||
self.use_cached_kv_cache_bytes = torchair_graph_config.get(
|
||||
"use_cached_kv_cache_bytes", False)
|
||||
self.graph_batch_sizes = torchair_graph_config.get(
|
||||
"graph_batch_sizes", [])
|
||||
self.graph_batch_sizes_init = torchair_graph_config.get(
|
||||
@@ -99,6 +101,10 @@ class TorchairGraphConfig:
|
||||
raise RuntimeError(
|
||||
"use_cached_graph is valid only when Torchair graph mode is enabled"
|
||||
)
|
||||
if self.use_cached_kv_cache_bytes:
|
||||
raise RuntimeError(
|
||||
"use_cached_kv_cache_bytes is valid only when Torchair graph mode is enabled"
|
||||
)
|
||||
if self.graph_batch_sizes:
|
||||
raise RuntimeError(
|
||||
"graph_batch_sizes is valid only when Torchair graph mode is enabled"
|
||||
@@ -119,6 +125,10 @@ class TorchairGraphConfig:
|
||||
raise RuntimeError(
|
||||
"enable_kv_nz is valid only when Torchair graph mode is enabled"
|
||||
)
|
||||
if self.use_cached_kv_cache_bytes and not self.use_cached_graph:
|
||||
raise RuntimeError(
|
||||
"use_cached_kv_cache_bytes is valid only when Torchair graph mode and use_cached_graph are enabled"
|
||||
)
|
||||
|
||||
|
||||
class AscendSchedulerConfig:
|
||||
|
||||
@@ -28,6 +28,8 @@ from vllm.platforms import Platform, PlatformEnum
|
||||
|
||||
from vllm_ascend.ascend_config import (check_ascend_config, get_ascend_config,
|
||||
init_ascend_config)
|
||||
from vllm_ascend.torchair.utils import (check_torchair_cache_exist,
|
||||
delete_torchair_cache_file)
|
||||
from vllm_ascend.utils import (ASCEND_QUANTIZATION_METHOD, is_310p,
|
||||
update_aclgraph_sizes)
|
||||
|
||||
@@ -170,6 +172,18 @@ class NPUPlatform(Platform):
|
||||
"Torchair compilation enabled on NPU. Setting CUDAGraphMode to NONE"
|
||||
)
|
||||
compilation_config.cudagraph_mode = CUDAGraphMode.NONE
|
||||
# Note: We delete the torchair cache folder here to prevent runtime issues caused by dimension
|
||||
# mismatches or configuration inconsistencies when users reuse cached computation graphs. Though
|
||||
# this will increase graph compilation duration, it significantly enhances robustness and decreases
|
||||
# graph launching time during inference.
|
||||
if check_torchair_cache_exist(
|
||||
) and not ascend_config.torchair_graph_config.use_cached_kv_cache_bytes:
|
||||
logger.warning(
|
||||
"Torchair cache folder is deleted here to prevent runtime issues caused by dimension "
|
||||
"mismatches or configuration inconsistencies when users reuse cached computation graphs. "
|
||||
"In order to decrease torchair graph compilation time, users can enable both use_cached_graph "
|
||||
"and use_cached_kv_cache_bytes in torchair_graph_config.")
|
||||
delete_torchair_cache_file()
|
||||
|
||||
if parallel_config.distributed_executor_backend == "ray":
|
||||
logger.warning(
|
||||
|
||||
@@ -628,6 +628,7 @@ class AscendMLATorchairImpl(MLAAttentionImpl):
|
||||
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
|
||||
self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz
|
||||
self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
|
||||
self.running_in_graph = False
|
||||
|
||||
# Adapt torch air graph mode with spec decoding.
|
||||
speculative_config = get_current_vllm_config().speculative_config
|
||||
|
||||
@@ -35,8 +35,9 @@ import vllm_ascend.envs as envs_ascend
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.platform import NPUPlatform
|
||||
from vllm_ascend.torchair.utils import (
|
||||
TorchairCommonAttentionMetadata, check_torchair_cache_exist,
|
||||
converting_weight_acl_format, register_torchair_model, torchair_ops_patch,
|
||||
TORCHAIR_CACHE_DIR, TorchairCommonAttentionMetadata,
|
||||
check_torchair_cache_exist, converting_weight_acl_format,
|
||||
register_torchair_model, torchair_ops_patch,
|
||||
torchair_quant_method_register, write_kv_cache_bytes_to_file)
|
||||
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
|
||||
is_310p)
|
||||
@@ -52,6 +53,7 @@ class NPUTorchairModelRunner(NPUModelRunner):
|
||||
self.torchair_compiled_model = None # type: ignore
|
||||
self.torchair_compiled_models = {} # type: ignore
|
||||
self.use_cached_npu_graph = ascend_config.torchair_graph_config.use_cached_graph
|
||||
self.use_cached_kv_cache_bytes = ascend_config.torchair_graph_config.use_cached_kv_cache_bytes
|
||||
self.torchair_graph_batch_sizes = ascend_config.torchair_graph_config.graph_batch_sizes
|
||||
if ascend_config.torchair_graph_config.graph_batch_sizes_init:
|
||||
self.init_torchair_graph_batch_sizes()
|
||||
@@ -194,14 +196,20 @@ class NPUTorchairModelRunner(NPUModelRunner):
|
||||
graph_num = len(torchair_graph_batch_sizes)
|
||||
|
||||
if self.use_cached_npu_graph and not check_torchair_cache_exist():
|
||||
# If caching is enabled but does not exist, we will compile the model twice. The first
|
||||
# time is used to generate the cache, and the second time is used to load the cache to
|
||||
# skip the overhead caused by Dynamo guard mechanism.
|
||||
# If caching is enabled but does not exist (either
|
||||
# use_cached_kv_cache_bytes is disabled or kv_cache_bytes are
|
||||
# different), we will compile the model twice. The first time is
|
||||
# used to generate the cache, and the second time is used to load the
|
||||
# cache to skip the overhead caused by Dynamo guard mechanism.
|
||||
logger.info(
|
||||
"Use cached npu graph but cache doesn't exist! Now we compile graph to genetate torchair cache, this usually takes %.1f~%.1f mins.",
|
||||
"Cache compilation for torchair graph is enabled. Now we compile graph to genetate"
|
||||
" torchair cache, this usually takes %.1f~%.1f mins.",
|
||||
0.5 * graph_num, 1.5 * graph_num)
|
||||
self._compile_torchair_graph(torchair_graph_batch_sizes)
|
||||
NPUPlatform.synchronize()
|
||||
# Note: We reset dynamo and reload the compiled torchair cached computation graph below
|
||||
# that was compiled above. This operation reduces graph launch time by 2-4ms and avoids
|
||||
# runtime errors caused by configuration mismatches in graph mode.
|
||||
torch._dynamo.reset()
|
||||
self.torchair_compiled_models.clear()
|
||||
if self.use_cached_npu_graph:
|
||||
@@ -215,7 +223,7 @@ class NPUTorchairModelRunner(NPUModelRunner):
|
||||
0.5 * graph_num, 1.5 * graph_num)
|
||||
self._compile_torchair_graph(torchair_graph_batch_sizes)
|
||||
|
||||
if self.new_kv_cache_bytes > 0:
|
||||
if self.use_cached_kv_cache_bytes and self.new_kv_cache_bytes > 0:
|
||||
write_kv_cache_bytes_to_file(torch.distributed.get_rank(),
|
||||
self.new_kv_cache_bytes)
|
||||
|
||||
@@ -362,6 +370,7 @@ class NPUTorchairModelRunner(NPUModelRunner):
|
||||
self.model.__dict__[forward_proxy_name],
|
||||
dynamic=True,
|
||||
fullgraph=envs_vllm.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
|
||||
cache_dir=TORCHAIR_CACHE_DIR,
|
||||
config=config,
|
||||
ge_cache=False)
|
||||
return self.torchair_compiled_models[batch_size]
|
||||
|
||||
@@ -17,9 +17,9 @@ import torch
|
||||
from vllm.logger import logger
|
||||
|
||||
import vllm_ascend.envs as envs_ascend
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.torchair.torchair_model_runner import NPUTorchairModelRunner
|
||||
from vllm_ascend.torchair.utils import (check_kv_cache_bytes_cache_exist,
|
||||
check_torchair_cache_exist,
|
||||
delete_torchair_cache_file,
|
||||
read_kv_cache_bytes_from_file)
|
||||
from vllm_ascend.worker.worker_v1 import NPUWorker
|
||||
@@ -33,7 +33,9 @@ class NPUTorchairWorker(NPUWorker):
|
||||
|
||||
available_kv_cache_memory = super().determine_available_memory()
|
||||
|
||||
if check_torchair_cache_exist() and check_kv_cache_bytes_cache_exist():
|
||||
if get_ascend_config(
|
||||
).torchair_graph_config.use_cached_kv_cache_bytes and check_kv_cache_bytes_cache_exist(
|
||||
):
|
||||
old_kv_cache_bytes = read_kv_cache_bytes_from_file(
|
||||
torch.distributed.get_rank())
|
||||
if 0 < old_kv_cache_bytes <= available_kv_cache_memory:
|
||||
|
||||
@@ -18,8 +18,8 @@ except ImportError:
|
||||
KV_CACHE_BYTES_CACHE_PATH_NAME = ".kv_cache_bytes"
|
||||
KV_CACHE_BYTES_CACHE_FILE_NAME = "kv_cache_bytes"
|
||||
TORCHAIR_CACHE_PATH_NAME = ".torchair_cache"
|
||||
TORCHAIR_CACHE_DIR = os.getenv(
|
||||
'TORCHAIR_CACHE_HOME', os.path.join(os.getcwd(), TORCHAIR_CACHE_PATH_NAME))
|
||||
TORCHAIR_CACHE_DIR = os.path.join(
|
||||
os.getenv('TORCHAIR_CACHE_HOME', os.getcwd()), TORCHAIR_CACHE_PATH_NAME)
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -111,8 +111,10 @@ def write_kv_cache_bytes_to_file(rank, kv_cache_bytes):
|
||||
|
||||
def delete_torchair_cache_file():
|
||||
torch_air_abs_path = _get_torchair_current_work_dir()
|
||||
if os.path.exists(torch_air_abs_path):
|
||||
try:
|
||||
shutil.rmtree(torch_air_abs_path)
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
|
||||
|
||||
def npu_stream_switch(tag: str, priority: int, *, enabled: bool = True):
|
||||
|
||||
@@ -20,7 +20,8 @@ from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
|
||||
from vllm_ascend.models.deepseek_mtp import CustomDeepSeekMTP
|
||||
from vllm_ascend.torchair.models.torchair_deepseek_mtp import \
|
||||
TorchairDeepSeekMTP
|
||||
from vllm_ascend.torchair.utils import TorchairCommonAttentionMetadata
|
||||
from vllm_ascend.torchair.utils import (TORCHAIR_CACHE_DIR,
|
||||
TorchairCommonAttentionMetadata)
|
||||
from vllm_ascend.utils import ProfileExecuteDuration, lmhead_tp_enable
|
||||
|
||||
|
||||
@@ -411,6 +412,7 @@ class MtpProposer:
|
||||
self.model.__dict__[forward_proxy_name],
|
||||
dynamic=True,
|
||||
fullgraph=envs_vllm.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
|
||||
cache_dir=TORCHAIR_CACHE_DIR,
|
||||
config=config,
|
||||
ge_cache=False)
|
||||
return self.torchair_compiled_models[batch_size]
|
||||
|
||||
Reference in New Issue
Block a user