enginex-mlu370-vllm/vllm-v0.6.2/vllm_mlu/vllm_mlu/worker/cache_engine.py
"""CacheEngine class for managing the KV cache."""
from typing import List, Tuple, Optional
import torch
from vllm.config import CacheConfig, ModelConfig, ParallelConfig
from vllm.logger import init_logger
from vllm.utils import is_pin_memory_available, get_dtype_size
from vllm.worker.cache_engine import CacheEngine
from vllm_mlu.mlu_hijack_utils import MluHijackObject
logger = init_logger(__name__)
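# This module patches vllm.worker.cache_engine.CacheEngine for MLU: the KV
# cache gains per-block float32 scales when stored as int8, and allocation is
# capped so no cache tensor exceeds the CNNL INT32 element limit.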
def vllm__worker__cache_engine__CacheEngine___allocate_kv_cache(
    self,
    num_blocks: int,
    device: str,
) -> List[List[torch.Tensor]]:
    """Allocates KV cache on the specified device."""
    kv_cache_shape = self.attn_backend.get_kv_cache_shape(
        num_blocks, self.block_size, self.num_kv_heads, self.head_size)
    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: add kv_cache_scale for int8 support;
            cap num_blocks to avoid exceeding CNNL int32 element limit
    '''
    # CNNL operators have a max supported tensor element count of INT32_MAX.
    # num_blocks should already be capped by determine_num_available_blocks;
    # this is a defensive check to catch any edge cases.
    CNNL_MAX_TENSOR_ELEMENTS = 2**31 - 1
    total_elements = 1
    for dim in kv_cache_shape:
        total_elements *= dim
    if total_elements > CNNL_MAX_TENSOR_ELEMENTS:
        elements_per_block = total_elements // num_blocks
        max_num_blocks = CNNL_MAX_TENSOR_ELEMENTS // elements_per_block
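        # Illustrative arithmetic (hypothetical numbers, not from a real
        # config): if each block holds 2 * 16 * 8 * 128 = 32768 cache
        # elements (K/V * block_size * num_kv_heads * head_size), the INT32
        # cap allows at most (2**31 - 1) // 32768 = 65535 blocks.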
        logger.warning(
            "KV cache tensor elements (%d) exceed CNNL max (%d). "
            "Reducing num_blocks from %d to %d. This indicates "
            "determine_num_available_blocks did not cap correctly.",
            total_elements, CNNL_MAX_TENSOR_ELEMENTS,
            num_blocks, max_num_blocks)
        num_blocks = max_num_blocks
        kv_cache_shape = self.attn_backend.get_kv_cache_shape(
            num_blocks, self.block_size, self.num_kv_heads, self.head_size)
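    # Per-block dequantization scales for the int8 KV cache. Judging by the
    # size accounting in get_cache_block_size below, this amounts to one
    # float32 scale per token slot per KV head for both key and value blocks;
    # the exact layout is defined by the MLU attention backend.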
    kv_cache_scales_shape = self.attn_backend.get_kv_cache_scale_shape(
        num_blocks, self.block_size, self.num_kv_heads)
    pin_memory = is_pin_memory_available() if device == "cpu" else False
    kv_cache: List[List[torch.Tensor]] = []
    for _ in range(self.num_attention_layers):
        # null block in CpuGpuBlockAllocator requires at least that
        # block to be zeroed-out.
        # We zero-out everything for simplicity.
        kv_cache_ = torch.zeros(kv_cache_shape,
                                dtype=self.dtype,
                                pin_memory=pin_memory,
                                device=device)
        if self.dtype == torch.int8:
            kv_cache_scale_ = torch.zeros(kv_cache_scales_shape,
                                          dtype=torch.float32,
                                          pin_memory=pin_memory,
                                          device=device)
        else:
            kv_cache_scale_ = torch.tensor([],
                                           dtype=torch.float32,
                                           device=device)
        kv_cache.append([kv_cache_, kv_cache_scale_])
    '''
    ==================
    End of MLU Hijack
    ==================
    '''
    return kv_cache

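# Each layer's cache entry is a [kv_cache, kv_cache_scale] pair, so the swap
# paths below must move both tensors when the cache dtype is int8.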
def vllm__worker__cache_engine__CacheEngine__swap_in(self, src_to_dst: torch.Tensor) -> None:
    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: swap kv_cache_scale for int8 support
    '''
    for i in range(self.num_attention_layers):
        # swap kv_cache
        self.attn_backend.swap_blocks(self.cpu_cache[i][0], self.gpu_cache[i][0],
                                      src_to_dst)
        if self.dtype == torch.int8:
            # swap kv_cache_scale
            self.attn_backend.swap_blocks(self.cpu_cache[i][1], self.gpu_cache[i][1],
                                          src_to_dst)
    '''
    ==================
    End of MLU Hijack
    ==================
    '''

def vllm__worker__cache_engine__CacheEngine__swap_out(self, src_to_dst: torch.Tensor) -> None:
    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: swap kv_cache_scale for int8 support
    '''
    for i in range(self.num_attention_layers):
        # swap kv_cache
        self.attn_backend.swap_blocks(self.gpu_cache[i][0], self.cpu_cache[i][0],
                                      src_to_dst)
        if self.dtype == torch.int8:
            # swap kv_cache_scale
            self.attn_backend.swap_blocks(self.gpu_cache[i][1], self.cpu_cache[i][1],
                                          src_to_dst)
    '''
    ==================
    End of MLU Hijack
    ==================
    '''

vllm__worker__cache_engine__CacheEngine__get_cache_block_size__org = CacheEngine.get_cache_block_size
@staticmethod
def vllm__worker__cache_engine__CacheEngine__get_cache_block_size(
    cache_config: CacheConfig,
    model_config: ModelConfig,
    parallel_config: ParallelConfig,
) -> int:
    kv_cache_total_size = vllm__worker__cache_engine__CacheEngine__get_cache_block_size__org(
        cache_config=cache_config,
        model_config=model_config,
        parallel_config=parallel_config
    )
    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: compute kv_cache_scale total size
    '''
    num_heads = model_config.get_num_kv_heads(parallel_config)
    num_attention_layers = model_config.get_num_attention_layers(parallel_config)
    kv_cache_scale_total_size = 0
    if cache_config.cache_dtype == 'int8':
        key_cache_scale_block = cache_config.block_size * num_heads
        value_cache_scale_block = key_cache_scale_block
        scale_total = num_attention_layers * (key_cache_scale_block + value_cache_scale_block)
        dtype_size = get_dtype_size(torch.float32)
        kv_cache_scale_total_size = dtype_size * scale_total
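        # Worked example (hypothetical numbers): block_size=16, 8 KV heads,
        # 32 attention layers -> 32 * (16*8 + 16*8) = 8192 float32 scales
        # per block, i.e. 32768 extra bytes added to each cache block.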
    '''
    ==================
    End of MLU Hijack
    ==================
    '''
    return kv_cache_total_size + kv_cache_scale_total_size

MluHijackObject.apply_hijack(CacheEngine,
                             CacheEngine._allocate_kv_cache,
                             vllm__worker__cache_engine__CacheEngine___allocate_kv_cache)
MluHijackObject.apply_hijack(CacheEngine,
                             CacheEngine.swap_in,
                             vllm__worker__cache_engine__CacheEngine__swap_in)
MluHijackObject.apply_hijack(CacheEngine,
                             CacheEngine.swap_out,
                             vllm__worker__cache_engine__CacheEngine__swap_out)
MluHijackObject.apply_hijack(CacheEngine,
                             CacheEngine.get_cache_block_size,
                             vllm__worker__cache_engine__CacheEngine__get_cache_block_size)
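# Conceptually (a minimal sketch, assuming MluHijackObject.apply_hijack simply
# rebinds the class attribute and keeps the original for reference), each call
# above amounts to a monkey-patch such as:
#
#     CacheEngine._allocate_kv_cache = \
#         vllm__worker__cache_engine__CacheEngine___allocate_kv_cache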