"""CacheEngine class for managing the KV cache."""

from typing import List, Tuple, Optional

import torch

from vllm.config import CacheConfig, ModelConfig, ParallelConfig
from vllm.logger import init_logger
from vllm.utils import is_pin_memory_available, get_dtype_size
from vllm.worker.cache_engine import CacheEngine
from vllm_mlu.mlu_hijack_utils import MluHijackObject

logger = init_logger(__name__)


def vllm__worker__cache_engine__CacheEngine___allocate_kv_cache(
    self,
    num_blocks: int,
    device: str,
) -> List[List[torch.Tensor]]:
    """Allocates KV cache on the specified device."""
    kv_cache_shape = self.attn_backend.get_kv_cache_shape(
        num_blocks, self.block_size, self.num_kv_heads, self.head_size)
    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: add kv_cache_scale for int8 support;
            cap num_blocks to avoid exceeding the CNNL int32 element limit
    '''
    # CNNL operators support at most INT32_MAX elements per tensor.
    # num_blocks should already be capped by determine_num_available_blocks;
    # this is a defensive check to catch any edge cases.
    CNNL_MAX_TENSOR_ELEMENTS = 2**31 - 1
    total_elements = 1
    for dim in kv_cache_shape:
        total_elements *= dim
    if total_elements > CNNL_MAX_TENSOR_ELEMENTS:
        elements_per_block = total_elements // num_blocks
        max_num_blocks = CNNL_MAX_TENSOR_ELEMENTS // elements_per_block
        logger.warning(
            "KV cache tensor elements (%d) exceed CNNL max (%d). "
            "Reducing num_blocks from %d to %d. This indicates "
            "determine_num_available_blocks did not cap correctly.",
            total_elements, CNNL_MAX_TENSOR_ELEMENTS,
            num_blocks, max_num_blocks)
        num_blocks = max_num_blocks
        kv_cache_shape = self.attn_backend.get_kv_cache_shape(
            num_blocks, self.block_size, self.num_kv_heads, self.head_size)

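    # Illustrative arithmetic for the cap above (assumed numbers, not an
    # actual backend shape): if get_kv_cache_shape returned
    # (2, num_blocks, 16, 8, 128), each block would hold
    # 2 * 16 * 8 * 128 = 32768 elements, so num_blocks would be capped at
    # (2**31 - 1) // 32768 = 65535.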
    kv_cache_scales_shape = self.attn_backend.get_kv_cache_scale_shape(
        num_blocks, self.block_size, self.num_kv_heads)
    pin_memory = is_pin_memory_available() if device == "cpu" else False
    kv_cache: List[List[torch.Tensor]] = []
    for _ in range(self.num_attention_layers):
        # null block in CpuGpuBlockAllocator requires at least that
        # block to be zeroed-out.
        # We zero-out everything for simplicity.
        kv_cache_ = torch.zeros(kv_cache_shape,
                                dtype=self.dtype,
                                pin_memory=pin_memory,
                                device=device)
        if self.dtype == torch.int8:
            kv_cache_scale_ = torch.zeros(kv_cache_scales_shape,
                                          dtype=torch.float32,
                                          pin_memory=pin_memory,
                                          device=device)
        else:
            kv_cache_scale_ = torch.tensor([],
                                           dtype=torch.float32,
                                           device=device)
        kv_cache.append([kv_cache_, kv_cache_scale_])
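    # Each layer entry is a [kv_cache_, kv_cache_scale_] pair; for non-int8
    # dtypes the scale tensor is left empty so the swap paths below can skip
    # it cheaply.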
    '''
    ==================
    End of MLU Hijack
    ==================
    '''
    return kv_cache


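# Both swap directions mirror the allocation layout above: index [i][0] is the
# KV cache tensor and [i][1] its int8 scale. src_to_dst is the block-number
# mapping tensor handed down by vLLM (assumed to hold (source block,
# destination block) pairs, as with the stock CacheEngine).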
def vllm__worker__cache_engine__CacheEngine__swap_in(
        self, src_to_dst: torch.Tensor) -> None:
    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: swap kv_cache_scale for int8 support
    '''
    for i in range(self.num_attention_layers):
        # swap kv_cache
        self.attn_backend.swap_blocks(self.cpu_cache[i][0],
                                      self.gpu_cache[i][0], src_to_dst)
        if self.dtype == torch.int8:
            # swap kv_cache_scale
            self.attn_backend.swap_blocks(self.cpu_cache[i][1],
                                          self.gpu_cache[i][1], src_to_dst)
    '''
    ==================
    End of MLU Hijack
    ==================
    '''


def vllm__worker__cache_engine__CacheEngine__swap_out(
        self, src_to_dst: torch.Tensor) -> None:
    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: swap kv_cache_scale for int8 support
    '''
    for i in range(self.num_attention_layers):
        # swap kv_cache
        self.attn_backend.swap_blocks(self.gpu_cache[i][0],
                                      self.cpu_cache[i][0], src_to_dst)
        if self.dtype == torch.int8:
            # swap kv_cache_scale
            self.attn_backend.swap_blocks(self.gpu_cache[i][1],
                                          self.cpu_cache[i][1], src_to_dst)
    '''
    ==================
    End of MLU Hijack
    ==================
    '''


vllm__worker__cache_engine__CacheEngine__get_cache_block_size__org = \
    CacheEngine.get_cache_block_size


@staticmethod
def vllm__worker__cache_engine__CacheEngine__get_cache_block_size(
    cache_config: CacheConfig,
    model_config: ModelConfig,
    parallel_config: ParallelConfig,
) -> int:
    kv_cache_total_size = \
        vllm__worker__cache_engine__CacheEngine__get_cache_block_size__org(
            cache_config=cache_config,
            model_config=model_config,
            parallel_config=parallel_config)
    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: compute kv_cache_scale total size
    '''
    num_heads = model_config.get_num_kv_heads(parallel_config)
    num_attention_layers = model_config.get_num_attention_layers(
        parallel_config)

    kv_cache_scale_total_size = 0
    if cache_config.cache_dtype == 'int8':
        key_cache_scale_block = cache_config.block_size * num_heads
        value_cache_scale_block = key_cache_scale_block
        scale_total = num_attention_layers * (key_cache_scale_block +
                                              value_cache_scale_block)
        dtype_size = get_dtype_size(torch.float32)
        kv_cache_scale_total_size = dtype_size * scale_total

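    # Illustrative arithmetic (assumed values, not taken from a real config):
    # with block_size=16, num_heads=8 and num_attention_layers=32, the int8
    # scale overhead per block is 32 * (16*8 + 16*8) * 4 bytes = 32 KiB of
    # float32 scales on top of the base block size.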
    '''
    ==================
    End of MLU Hijack
    ==================
    '''
    return kv_cache_total_size + kv_cache_scale_total_size


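# Install the MLU-aware replacements on the stock CacheEngine.
# MluHijackObject.apply_hijack is assumed to replace the given original
# attribute with the replacement function on the target class.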
MluHijackObject.apply_hijack(CacheEngine,
                             CacheEngine._allocate_kv_cache,
                             vllm__worker__cache_engine__CacheEngine___allocate_kv_cache)
MluHijackObject.apply_hijack(CacheEngine,
                             CacheEngine.swap_in,
                             vllm__worker__cache_engine__CacheEngine__swap_in)
MluHijackObject.apply_hijack(CacheEngine,
                             CacheEngine.swap_out,
                             vllm__worker__cache_engine__CacheEngine__swap_out)
MluHijackObject.apply_hijack(CacheEngine,
                             CacheEngine.get_cache_block_size,
                             vllm__worker__cache_engine__CacheEngine__get_cache_block_size)