feat: add kv cache memory cache and skip dynamo guard (#1549)

### What this PR does / why we need it?

1、Sometimes loading torchair cache will fail because of the floating of
npu memory, so this pr add a new cache to save the old kv cache bytes to
avoid the possible crash while loading the torchair graph cache.
2、When caching is enabled and does not exist, the first compilation
introduces the overhead of Dynamo Gurad. So in this case, we will
compile them directly twice to skip them (This will bring 3-4 ms of tpot
optimization)

### Does this PR introduce _any_ user-facing change?
Add a new env `VLLM_ASCEND_KV_CACHE_MEGABYTES_FLOATING_TOLERANCE` to
control kv cache floating tolerance

### How was this patch tested?

- vLLM version: v0.9.1
- vLLM main:
1fd471e957

Signed-off-by: boying <897013703@qq.com>
This commit is contained in:
NeverRaR
2025-07-07 22:37:14 +08:00
committed by GitHub
parent df84cceca8
commit 71de52d3a9
5 changed files with 182 additions and 24 deletions

View File

@@ -280,6 +280,27 @@ class TestUtils(TestBase):
3,
len(test_vllm_config.compilation_config.cudagraph_capture_sizes))
def test_get_torchair_current_work_dir(self):
cache_dir = utils.TORCHAIR_CACHE_DIR
work_dir = utils.get_torchair_current_work_dir()
self.assertEqual(cache_dir, work_dir)
work_dir = utils.get_torchair_current_work_dir("test")
self.assertEqual(os.path.join(cache_dir, "test"), work_dir)
def test_torchair_cache_dir(self):
utils.write_kv_cache_bytes_to_file(0, 100)
self.assertTrue(utils.check_torchair_cache_exist(),
"Create torchair cache dir failed")
self.assertTrue(utils.check_kv_cache_bytes_cache_exist(),
"Create kv cache bytes cache dir failed")
kv_cache_bytes = utils.read_kv_cache_bytes_from_file(0)
self.assertEqual(100, kv_cache_bytes)
utils.delete_torchair_cache_file()
self.assertFalse(utils.check_torchair_cache_exist(),
"Delete torchair cache dir failed")
self.assertFalse(utils.check_kv_cache_bytes_cache_exist(),
"Delete kv cache bytes cache dir failed")
class TestProfileExecuteDuration(unittest.TestCase):

View File

@@ -121,6 +121,12 @@ env_variables: Dict[str, Callable[[], Any]] = {
# value to False to disable the optimized model.
"USE_OPTIMIZED_MODEL":
lambda: bool(int(os.getenv('USE_OPTIMIZED_MODEL', '1'))),
# The tolerance of the kv cache size, if the difference between the
# actual kv cache size and the cached kv cache size is less than this value,
# then the cached kv cache size will be used.
"VLLM_ASCEND_KV_CACHE_MEGABYTES_FLOATING_TOLERANCE":
lambda: int(
os.getenv("VLLM_ASCEND_KV_CACHE_MEGABYTES_FLOATING_TOLERANCE", 64)),
}
# end-env-vars-definition

View File

@@ -18,7 +18,10 @@
#
import atexit
import fcntl
import math
import os
import shutil
from contextlib import contextmanager, nullcontext
from enum import Enum
from threading import Lock
@@ -440,3 +443,77 @@ def get_fused_moe_state(ep_size: int, with_prefill: bool,
return FusedMoEState.All2All
else:
return FusedMoEState.MC2
KV_CACHE_BYTES_CACHE_PATH_NAME = ".kv_cache_bytes"
KV_CACHE_BYTES_CACHE_FILE_NAME = "kv_cache_bytes"
TORCHAIR_CACHE_PATH_NAME = ".torchair_cache"
TORCHAIR_CACHE_DIR = os.getenv(
'TORCHAIR_CACHE_HOME', os.path.join(os.getcwd(), TORCHAIR_CACHE_PATH_NAME))
def get_torchair_current_work_dir(file_name=None):
if file_name is None:
return TORCHAIR_CACHE_DIR
return os.path.join(TORCHAIR_CACHE_DIR, file_name)
def check_torchair_cache_exist():
res = False
torch_air_abs_path = get_torchair_current_work_dir()
if os.path.exists(torch_air_abs_path):
file_list = os.listdir(torch_air_abs_path)
if len(file_list) != 0:
res = True
return res
def check_kv_cache_bytes_cache_exist():
res = False
kv_cache_bytes_cache_abs_path = get_torchair_current_work_dir(
KV_CACHE_BYTES_CACHE_PATH_NAME)
if os.path.exists(kv_cache_bytes_cache_abs_path):
file_list = os.listdir(kv_cache_bytes_cache_abs_path)
if len(file_list) != 0:
res = True
return res
def read_kv_cache_bytes_from_file(rank) -> int:
kv_cache_bytes = -1
kv_cache_bytes_cache_abs_path = get_torchair_current_work_dir(
KV_CACHE_BYTES_CACHE_PATH_NAME)
kv_cache_bytes_file = os.path.join(
kv_cache_bytes_cache_abs_path,
f"{rank}_{KV_CACHE_BYTES_CACHE_FILE_NAME}")
with open(kv_cache_bytes_file, "r", encoding="utf-8") as f:
with file_lock(f, fcntl.LOCK_SH):
kv_cache_bytes = int(f.readline())
return kv_cache_bytes
@contextmanager
def file_lock(file_descriptor, lock_type):
fcntl.flock(file_descriptor, lock_type)
try:
yield
finally:
fcntl.flock(file_descriptor, fcntl.LOCK_UN)
def write_kv_cache_bytes_to_file(rank, kv_cache_bytes):
kv_cache_bytes_cache_abs_path = get_torchair_current_work_dir(
KV_CACHE_BYTES_CACHE_PATH_NAME)
os.makedirs(kv_cache_bytes_cache_abs_path, exist_ok=True)
kv_cache_bytes_file = os.path.join(
kv_cache_bytes_cache_abs_path,
f"{rank}_{KV_CACHE_BYTES_CACHE_FILE_NAME}")
with open(kv_cache_bytes_file, "w", encoding="utf-8") as f:
with file_lock(f, fcntl.LOCK_EX):
f.write(f"{kv_cache_bytes}")
def delete_torchair_cache_file():
torch_air_abs_path = get_torchair_current_work_dir()
if os.path.exists(torch_air_abs_path):
shutil.rmtree(torch_air_abs_path)

View File

@@ -76,9 +76,10 @@ from vllm_ascend.platform import NPUPlatform
from vllm_ascend.pool.metadata import PoolingMetadata
from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
ProfileExecuteDuration, is_310p,
ProfileExecuteDuration,
check_torchair_cache_exist, is_310p,
maybe_converting_weight_acl_format,
vllm_version_is)
vllm_version_is, write_kv_cache_bytes_to_file)
from vllm_ascend.worker.eagle_proposer_v1 import EagleProposer
from vllm_ascend.worker.mtp_proposer_v1 import MtpProposer
from vllm_ascend.worker.npu_input_batch import CachedRequestState, InputBatch
@@ -329,6 +330,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.attn_mask_builder = AttentionMaskBuilder.initialize_from_len(
attn_mask_len, self.dtype)
self.new_kv_cache_bytes = -1
self.torchair_compiled_model = None # type: ignore
self.torchair_compiled_models = {} # type: ignore
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
@@ -2274,6 +2276,20 @@ class NPUModelRunner(LoRAModelRunnerMixin):
return kv_cache_spec
def _compile_torchair_graph(self, torchair_graph_batch_sizes) -> None:
# Trigger torchair graph capture for specific shapes.
# Capture the large shapes first so that the smaller shapes
# can reuse the memory pool allocated for the large shapes.
for idx, num_tokens in enumerate(reversed(torchair_graph_batch_sizes)):
for _ in range(self.vllm_config.compilation_config.
cudagraph_num_of_warmups):
self._dummy_run(num_tokens,
is_compile=True,
with_prefill=False)
self._dummy_run(num_tokens, is_compile=True, with_prefill=False)
logger.info("Batchsize %d is compiled successfully: %d/%d.",
num_tokens, idx + 1, len(torchair_graph_batch_sizes))
def capture_model(self) -> None:
start_time = time.perf_counter()
start_free_npu_memory = torch.npu.mem_get_info()[0]
@@ -2283,24 +2299,32 @@ class NPUModelRunner(LoRAModelRunnerMixin):
if self.torchair_graph_enabled:
torchair_graph_batch_sizes = self.torchair_graph_batch_sizes
graph_num = len(torchair_graph_batch_sizes)
logger.info(
"Capturing torchair graph, this usually takes %.1f~%.1f mins.",
0.5 * graph_num, 1.5 * graph_num)
# Trigger torchair graph capture for specific shapes.
# Capture the large shapes first so that the smaller shapes
# can reuse the memory pool allocated for the large shapes.
for idx, num_tokens in enumerate(
reversed(torchair_graph_batch_sizes)):
for _ in range(self.vllm_config.compilation_config.
cudagraph_num_of_warmups):
self._dummy_run(num_tokens,
is_compile=True,
with_prefill=False)
self._dummy_run(num_tokens,
is_compile=True,
with_prefill=False)
logger.info("Batchsize %d is compiled successfully: %d/%d.",
num_tokens, idx + 1, graph_num)
if self.use_cached_npu_graph and not check_torchair_cache_exist():
# If caching is enabled but does not exist, we will compile the model twice. The first
# time is used to generate the cache, and the second time is used to load the cache to
# skip the overhead caused by Dynamo guard mechanism.
logger.info(
"Use cached npu graph but cache doesn't exist! Now we compile graph to genetate torchair cache, this usually takes %.1f~%.1f mins.",
0.5 * graph_num, 1.5 * graph_num)
self._compile_torchair_graph(torchair_graph_batch_sizes)
NPUPlatform.synchronize()
torch._dynamo.reset()
self.torchair_compiled_models.clear()
if self.use_cached_npu_graph:
logger.info(
"Loading torchair graph cache, this usually takes %.1f~%.1f mins.",
0.3 * graph_num, 0.5 * graph_num)
self._compile_torchair_graph(torchair_graph_batch_sizes)
else:
logger.info(
"Capturing torchair graph, this usually takes %.1f~%.1f mins.",
0.5 * graph_num, 1.5 * graph_num)
self._compile_torchair_graph(torchair_graph_batch_sizes)
if self.new_kv_cache_bytes > 0:
write_kv_cache_bytes_to_file(torch.distributed.get_rank(),
self.new_kv_cache_bytes)
elif self.use_aclgraph:
# Trigger ACL graph capture for specific shapes.
# Capture the large shapes first so that the smaller shapes

View File

@@ -36,11 +36,16 @@ from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.worker.worker_base import WorkerBase
from vllm_ascend.ascend_config import init_ascend_config
import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import get_ascend_config, init_ascend_config
from vllm_ascend.device_allocator.camem import CaMemAllocator
from vllm_ascend.distributed.parallel_state import init_ascend_model_parallel
from vllm_ascend.platform import NPUPlatform
from vllm_ascend.utils import sleep_mode_enabled, try_register_lib
from vllm_ascend.utils import (check_kv_cache_bytes_cache_exist,
check_torchair_cache_exist,
delete_torchair_cache_file,
read_kv_cache_bytes_from_file,
sleep_mode_enabled, try_register_lib)
from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
@@ -167,10 +172,35 @@ class NPUWorker(WorkerBase):
non_torch_allocations = total_allocated_bytes - torch_allocated_bytes
if non_torch_allocations > 0:
peak_memory += non_torch_allocations
available_kv_cache_memory = (
available_kv_cache_memory = int(
total_npu_memory * self.cache_config.gpu_memory_utilization -
peak_memory)
return int(available_kv_cache_memory)
available_kv_cache_memory = int(max(available_kv_cache_memory, 0))
logger.info(
f"Available memory: {available_kv_cache_memory}, total memory: {total_npu_memory}"
)
if get_ascend_config().torchair_graph_config.enabled:
if check_torchair_cache_exist(
) and check_kv_cache_bytes_cache_exist():
old_kv_cache_bytes = read_kv_cache_bytes_from_file(
torch.distributed.get_rank())
if 0 < old_kv_cache_bytes <= available_kv_cache_memory:
logger.info(
f"Use cached torchair kv_cache_bytes: {old_kv_cache_bytes}"
)
self.model_runner.new_kv_cache_bytes = old_kv_cache_bytes
return old_kv_cache_bytes
else:
logger.info(
"Cached torchair kv_cache_bytes is too big, invalidate old torchair_cache"
)
delete_torchair_cache_file()
bytes_floating_tolerance = 1024 * 1024 * envs_ascend.VLLM_ASCEND_KV_CACHE_MEGABYTES_FLOATING_TOLERANCE
available_kv_cache_memory -= bytes_floating_tolerance
logger.info(f"Use new kv_cache_bytes: {available_kv_cache_memory}")
self.model_runner.new_kv_cache_bytes = available_kv_cache_memory
return available_kv_cache_memory
def execute_model(
self,