feat: add kv cache memory cache and skip dynamo guard (#1549)
### What this PR does / why we need it?
1、Sometimes loading torchair cache will fail because of the floating of
npu memory, so this pr add a new cache to save the old kv cache bytes to
avoid the possible crash while loading the torchair graph cache.
2、When caching is enabled and does not exist, the first compilation
introduces the overhead of Dynamo Gurad. So in this case, we will
compile them directly twice to skip them (This will bring 3-4 ms of tpot
optimization)
### Does this PR introduce _any_ user-facing change?
Add a new env `VLLM_ASCEND_KV_CACHE_MEGABYTES_FLOATING_TOLERANCE` to
control kv cache floating tolerance
### How was this patch tested?
- vLLM version: v0.9.1
- vLLM main:
1fd471e957
Signed-off-by: boying <897013703@qq.com>
This commit is contained in:
@@ -36,11 +36,16 @@ from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
|
||||
from vllm.v1.outputs import ModelRunnerOutput
|
||||
from vllm.v1.worker.worker_base import WorkerBase
|
||||
|
||||
from vllm_ascend.ascend_config import init_ascend_config
|
||||
import vllm_ascend.envs as envs_ascend
|
||||
from vllm_ascend.ascend_config import get_ascend_config, init_ascend_config
|
||||
from vllm_ascend.device_allocator.camem import CaMemAllocator
|
||||
from vllm_ascend.distributed.parallel_state import init_ascend_model_parallel
|
||||
from vllm_ascend.platform import NPUPlatform
|
||||
from vllm_ascend.utils import sleep_mode_enabled, try_register_lib
|
||||
from vllm_ascend.utils import (check_kv_cache_bytes_cache_exist,
|
||||
check_torchair_cache_exist,
|
||||
delete_torchair_cache_file,
|
||||
read_kv_cache_bytes_from_file,
|
||||
sleep_mode_enabled, try_register_lib)
|
||||
from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
|
||||
|
||||
|
||||
@@ -167,10 +172,35 @@ class NPUWorker(WorkerBase):
|
||||
non_torch_allocations = total_allocated_bytes - torch_allocated_bytes
|
||||
if non_torch_allocations > 0:
|
||||
peak_memory += non_torch_allocations
|
||||
available_kv_cache_memory = (
|
||||
available_kv_cache_memory = int(
|
||||
total_npu_memory * self.cache_config.gpu_memory_utilization -
|
||||
peak_memory)
|
||||
return int(available_kv_cache_memory)
|
||||
available_kv_cache_memory = int(max(available_kv_cache_memory, 0))
|
||||
logger.info(
|
||||
f"Available memory: {available_kv_cache_memory}, total memory: {total_npu_memory}"
|
||||
)
|
||||
if get_ascend_config().torchair_graph_config.enabled:
|
||||
if check_torchair_cache_exist(
|
||||
) and check_kv_cache_bytes_cache_exist():
|
||||
old_kv_cache_bytes = read_kv_cache_bytes_from_file(
|
||||
torch.distributed.get_rank())
|
||||
if 0 < old_kv_cache_bytes <= available_kv_cache_memory:
|
||||
logger.info(
|
||||
f"Use cached torchair kv_cache_bytes: {old_kv_cache_bytes}"
|
||||
)
|
||||
self.model_runner.new_kv_cache_bytes = old_kv_cache_bytes
|
||||
return old_kv_cache_bytes
|
||||
else:
|
||||
logger.info(
|
||||
"Cached torchair kv_cache_bytes is too big, invalidate old torchair_cache"
|
||||
)
|
||||
delete_torchair_cache_file()
|
||||
bytes_floating_tolerance = 1024 * 1024 * envs_ascend.VLLM_ASCEND_KV_CACHE_MEGABYTES_FLOATING_TOLERANCE
|
||||
available_kv_cache_memory -= bytes_floating_tolerance
|
||||
logger.info(f"Use new kv_cache_bytes: {available_kv_cache_memory}")
|
||||
self.model_runner.new_kv_cache_bytes = available_kv_cache_memory
|
||||
|
||||
return available_kv_cache_memory
|
||||
|
||||
def execute_model(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user