Fixed KV cache bug
This commit is contained in:
@@ -28,19 +28,19 @@ def vllm__worker__cache_engine__CacheEngine___allocate_kv_cache(
|
|||||||
cap num_blocks to avoid exceeding CNNL int32 element limit
|
cap num_blocks to avoid exceeding CNNL int32 element limit
|
||||||
'''
|
'''
|
||||||
# CNNL operators have a max supported tensor element count of INT32_MAX.
|
# CNNL operators have a max supported tensor element count of INT32_MAX.
|
||||||
# If the kv_cache tensor would exceed this limit, reduce num_blocks.
|
# num_blocks should already be capped by determine_num_available_blocks;
|
||||||
|
# this is a defensive check to catch any edge cases.
|
||||||
CNNL_MAX_TENSOR_ELEMENTS = 2**31 - 1
|
CNNL_MAX_TENSOR_ELEMENTS = 2**31 - 1
|
||||||
total_elements = 1
|
total_elements = 1
|
||||||
for dim in kv_cache_shape:
|
for dim in kv_cache_shape:
|
||||||
total_elements *= dim
|
total_elements *= dim
|
||||||
if total_elements > CNNL_MAX_TENSOR_ELEMENTS:
|
if total_elements > CNNL_MAX_TENSOR_ELEMENTS:
|
||||||
# Calculate the max num_blocks that fits within the limit.
|
|
||||||
# kv_cache_shape = (2, num_blocks, num_kv_heads, block_size, head_size)
|
|
||||||
elements_per_block = total_elements // num_blocks
|
elements_per_block = total_elements // num_blocks
|
||||||
max_num_blocks = CNNL_MAX_TENSOR_ELEMENTS // elements_per_block
|
max_num_blocks = CNNL_MAX_TENSOR_ELEMENTS // elements_per_block
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"KV cache tensor elements (%d) exceed CNNL max (%d). "
|
"KV cache tensor elements (%d) exceed CNNL max (%d). "
|
||||||
"Reducing num_blocks from %d to %d.",
|
"Reducing num_blocks from %d to %d. This indicates "
|
||||||
|
"determine_num_available_blocks did not cap correctly.",
|
||||||
total_elements, CNNL_MAX_TENSOR_ELEMENTS,
|
total_elements, CNNL_MAX_TENSOR_ELEMENTS,
|
||||||
num_blocks, max_num_blocks)
|
num_blocks, max_num_blocks)
|
||||||
num_blocks = max_num_blocks
|
num_blocks = max_num_blocks
|
||||||
|
|||||||
Reference in New Issue
Block a user