diff --git a/vllm-v0.6.2/vllm_mlu/vllm_mlu/worker/cache_engine.py b/vllm-v0.6.2/vllm_mlu/vllm_mlu/worker/cache_engine.py
index 411484f..3be0fc2 100644
--- a/vllm-v0.6.2/vllm_mlu/vllm_mlu/worker/cache_engine.py
+++ b/vllm-v0.6.2/vllm_mlu/vllm_mlu/worker/cache_engine.py
@@ -24,8 +24,29 @@ def vllm__worker__cache_engine__CacheEngine___allocate_kv_cache(
     =============================
     Modify by vllm_mlu
     =============================
-    @brief: add kv_cache_scale for int8 support
-    '''
+    @brief: add kv_cache_scale for int8 support;
+            cap num_blocks to avoid exceeding CNNL int32 element limit
+    '''
+    # CNNL operators have a max supported tensor element count of INT32_MAX.
+    # If the kv_cache tensor would exceed this limit, reduce num_blocks.
+    CNNL_MAX_TENSOR_ELEMENTS = 2**31 - 1
+    total_elements = 1
+    for dim in kv_cache_shape:
+        total_elements *= dim
+    if total_elements > CNNL_MAX_TENSOR_ELEMENTS:
+        # Calculate the max num_blocks that fits within the limit.
+        # kv_cache_shape = (2, num_blocks, num_kv_heads, block_size, head_size)
+        elements_per_block = total_elements // num_blocks
+        max_num_blocks = CNNL_MAX_TENSOR_ELEMENTS // elements_per_block
+        logger.warning(
+            "KV cache tensor elements (%d) exceed CNNL max (%d). "
+            "Reducing num_blocks from %d to %d.",
+            total_elements, CNNL_MAX_TENSOR_ELEMENTS,
+            num_blocks, max_num_blocks)
+        num_blocks = max_num_blocks
+        kv_cache_shape = self.attn_backend.get_kv_cache_shape(
+            num_blocks, self.block_size, self.num_kv_heads, self.head_size)
+
     kv_cache_scales_shape = self.attn_backend.get_kv_cache_scale_shape(
         num_blocks, self.block_size, self.num_kv_heads)
     pin_memory = is_pin_memory_available() if device == "cpu" else False
diff --git a/vllm-v0.6.2/vllm_mlu/vllm_mlu/worker/mlu_worker.py b/vllm-v0.6.2/vllm_mlu/vllm_mlu/worker/mlu_worker.py
index c6c0328..5752c2e 100644
--- a/vllm-v0.6.2/vllm_mlu/vllm_mlu/worker/mlu_worker.py
+++ b/vllm-v0.6.2/vllm_mlu/vllm_mlu/worker/mlu_worker.py
@@ -95,6 +95,30 @@ class MLUWorker_V2(MLUWorker):
         num_gpu_blocks = max(num_gpu_blocks, 0)
         num_cpu_blocks = max(num_cpu_blocks, 0)
 
+        # Cap num_gpu_blocks to avoid exceeding CNNL's int32 tensor element
+        # limit. CNNL operators do not support tensors with more than
+        # 2^31 - 1 elements. The KV cache shape is typically
+        # (2, num_blocks, num_kv_heads, block_size, head_size), and when
+        # num_blocks is very large (e.g. for tiny models with huge free
+        # memory), the total element count can overflow.
+        CNNL_MAX_TENSOR_ELEMENTS = 2**31 - 1
+        block_size = self.cache_config.block_size
+        num_kv_heads = self.model_config.get_num_kv_heads(
+            self.parallel_config)
+        head_size = self.model_config.get_head_size()
+        # kv_cache_shape = (2, num_blocks, num_kv_heads, block_size, head_size)
+        elements_per_block = 2 * num_kv_heads * block_size * head_size
+        if elements_per_block > 0:
+            max_blocks_by_cnnl = CNNL_MAX_TENSOR_ELEMENTS // elements_per_block
+            if num_gpu_blocks > max_blocks_by_cnnl:
+                logger.warning(
+                    "Reducing num_gpu_blocks from %d to %d to stay within "
+                    "CNNL max tensor element limit (%d). "
+                    "elements_per_block=%d",
+                    num_gpu_blocks, max_blocks_by_cnnl,
+                    CNNL_MAX_TENSOR_ELEMENTS, elements_per_block)
+                num_gpu_blocks = max_blocks_by_cnnl
+
         logger.info(
             "Memory profiling results: total_gpu_memory=%.2fGiB"
             " initial_memory_usage=%.2fGiB peak_torch_memory=%.2fGiB"