forked from EngineX-Cambricon/enginex-mlu370-vllm
fixing kvcache bug
This commit is contained in:
@@ -24,8 +24,29 @@ def vllm__worker__cache_engine__CacheEngine___allocate_kv_cache(
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: add kv_cache_scale for int8 support
|
||||
'''
|
||||
@brief: add kv_cache_scale for int8 support;
|
||||
cap num_blocks to avoid exceeding CNNL int32 element limit
|
||||
'''
|
||||
# CNNL operators have a max supported tensor element count of INT32_MAX.
|
||||
# If the kv_cache tensor would exceed this limit, reduce num_blocks.
|
||||
CNNL_MAX_TENSOR_ELEMENTS = 2**31 - 1
|
||||
total_elements = 1
|
||||
for dim in kv_cache_shape:
|
||||
total_elements *= dim
|
||||
if total_elements > CNNL_MAX_TENSOR_ELEMENTS:
|
||||
# Calculate the max num_blocks that fits within the limit.
|
||||
# kv_cache_shape = (2, num_blocks, num_kv_heads, block_size, head_size)
|
||||
elements_per_block = total_elements // num_blocks
|
||||
max_num_blocks = CNNL_MAX_TENSOR_ELEMENTS // elements_per_block
|
||||
logger.warning(
|
||||
"KV cache tensor elements (%d) exceed CNNL max (%d). "
|
||||
"Reducing num_blocks from %d to %d.",
|
||||
total_elements, CNNL_MAX_TENSOR_ELEMENTS,
|
||||
num_blocks, max_num_blocks)
|
||||
num_blocks = max_num_blocks
|
||||
kv_cache_shape = self.attn_backend.get_kv_cache_shape(
|
||||
num_blocks, self.block_size, self.num_kv_heads, self.head_size)
|
||||
|
||||
kv_cache_scales_shape = self.attn_backend.get_kv_cache_scale_shape(
|
||||
num_blocks, self.block_size, self.num_kv_heads)
|
||||
pin_memory = is_pin_memory_available() if device == "cpu" else False
|
||||
|
||||
@@ -95,6 +95,30 @@ class MLUWorker_V2(MLUWorker):
|
||||
num_gpu_blocks = max(num_gpu_blocks, 0)
|
||||
num_cpu_blocks = max(num_cpu_blocks, 0)
|
||||
|
||||
# Cap num_gpu_blocks to avoid exceeding CNNL's int32 tensor element
|
||||
# limit. CNNL operators do not support tensors with more than
|
||||
# 2^31 - 1 elements. The KV cache shape is typically
|
||||
# (2, num_blocks, num_kv_heads, block_size, head_size), and when
|
||||
# num_blocks is very large (e.g. for tiny models with huge free
|
||||
# memory), the total element count can overflow.
|
||||
CNNL_MAX_TENSOR_ELEMENTS = 2**31 - 1
|
||||
block_size = self.cache_config.block_size
|
||||
num_kv_heads = self.model_config.get_num_kv_heads(
|
||||
self.parallel_config)
|
||||
head_size = self.model_config.get_head_size()
|
||||
# kv_cache_shape = (2, num_blocks, num_kv_heads, block_size, head_size)
|
||||
elements_per_block = 2 * num_kv_heads * block_size * head_size
|
||||
if elements_per_block > 0:
|
||||
max_blocks_by_cnnl = CNNL_MAX_TENSOR_ELEMENTS // elements_per_block
|
||||
if num_gpu_blocks > max_blocks_by_cnnl:
|
||||
logger.warning(
|
||||
"Reducing num_gpu_blocks from %d to %d to stay within "
|
||||
"CNNL max tensor element limit (%d). "
|
||||
"elements_per_block=%d",
|
||||
num_gpu_blocks, max_blocks_by_cnnl,
|
||||
CNNL_MAX_TENSOR_ELEMENTS, elements_per_block)
|
||||
num_gpu_blocks = max_blocks_by_cnnl
|
||||
|
||||
logger.info(
|
||||
"Memory profiling results: total_gpu_memory=%.2fGiB"
|
||||
" initial_memory_usage=%.2fGiB peak_torch_memory=%.2fGiB"
|
||||
|
||||
Reference in New Issue
Block a user