# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project

import numpy as np
import torch

from vllm.distributed import get_dcp_group
from vllm.logger import init_logger
from vllm.v1.worker.block_table import BlockTable

from vllm_mlu.mlu_hijack_utils import MluHijackObject

logger = init_logger(__name__)


class BlockTable_MluHijack(BlockTable):

    def __init__(
        self,
        block_size: int,
        max_num_reqs: int,
        max_num_blocks_per_req: int,
        max_num_batched_tokens: int,
        pin_memory: bool,
        device: torch.device,
        kernel_block_size: int,
        dcp_kv_cache_interleave_size: int,
    ):
        """
        Args:
            block_size: Block size used for KV cache memory allocation.
            max_num_reqs: Maximum number of concurrent requests supported.
            max_num_blocks_per_req: Maximum number of blocks per request.
            max_num_batched_tokens: Maximum number of tokens in a batch.
            pin_memory: Whether to pin memory for faster GPU transfers.
            device: Target device for the block table.
            kernel_block_size: The block_size of the underlying attention
                kernel. Equal to `block_size` if `block_size` is supported
                by the attention kernel.
            dcp_kv_cache_interleave_size: Interleave size of the KV cache
                across DCP ranks.
        """
        self.max_num_reqs = max_num_reqs
        self.max_num_batched_tokens = max_num_batched_tokens
        self.pin_memory = pin_memory
        self.device = device

        if kernel_block_size == block_size:
            # Standard case: allocation and computation use the same block
            # size. No block splitting is needed; the mapping is direct.
            self.block_size = block_size
            self.blocks_per_kv_block = 1
            self.use_hybrid_blocks = False
        else:
            # Hybrid case: the allocation block size differs from the kernel
            # block size. Memory blocks are subdivided to match kernel
            # requirements.
            # Example: 32-token memory blocks with 16-token kernel blocks
            # -> each memory block corresponds to 2 kernel blocks.
            if block_size % kernel_block_size != 0:
                raise ValueError(
                    f"kernel_block_size {kernel_block_size} must divide "
                    f"kv_manager_block_size {block_size} evenly"
                )
            self.block_size = kernel_block_size
            self.blocks_per_kv_block = block_size // kernel_block_size
            self.use_hybrid_blocks = True

        self.max_num_blocks_per_req = (
            max_num_blocks_per_req * self.blocks_per_kv_block
        )

        self.block_table = self._make_buffer(
            self.max_num_reqs, self.max_num_blocks_per_req, dtype=torch.int32
        )
        self.num_blocks_per_row = np.zeros(max_num_reqs, dtype=np.int32)
        '''
        ============================= Modify by vllm_mlu =============================
        @brief: change slot_mapping dtype from int64 to int32
        '''
        self.slot_mapping = self._make_buffer(
            self.max_num_batched_tokens, dtype=torch.int32
        )
        '''
        ================== End of MLU Hijack ==================
        '''

        if self.use_hybrid_blocks:
            self._kernel_block_arange = np.arange(
                0, self.blocks_per_kv_block
            ).reshape(1, -1)
        else:
            self._kernel_block_arange = None

        try:
            self.dcp_world_size = get_dcp_group().world_size
            self.dcp_rank = get_dcp_group().rank_in_group
        except AssertionError:
            # DCP might not be initialized in testing.
            self.dcp_world_size = 1
            self.dcp_rank = 0
        self.dcp_kv_cache_interleave_size = dcp_kv_cache_interleave_size


MluHijackObject.apply_hijack(
    BlockTable, BlockTable.__init__, BlockTable_MluHijack.__init__
)
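

# ----------------------------------------------------------------------------
# Illustrative sketch (not part of the hijack): how a KV-manager block id
# expands into kernel block ids in the hybrid case above. The helper name
# `expand_to_kernel_blocks` is hypothetical; the arithmetic mirrors what
# `blocks_per_kv_block` and `_kernel_block_arange` imply.


def expand_to_kernel_blocks(
    kv_block_ids: np.ndarray, blocks_per_kv_block: int
) -> np.ndarray:
    # Each manager block covers `blocks_per_kv_block` contiguous kernel
    # blocks, so manager block b maps to kernel blocks
    # [b * blocks_per_kv_block, ..., (b + 1) * blocks_per_kv_block - 1].
    arange = np.arange(blocks_per_kv_block).reshape(1, -1)
    return (
        kv_block_ids.reshape(-1, 1) * blocks_per_kv_block + arange
    ).reshape(-1)


if __name__ == "__main__":
    # 32-token manager blocks with 16-token kernel blocks
    # (blocks_per_kv_block == 2): manager blocks [5, 7] map to kernel
    # blocks [10, 11, 14, 15].
    print(expand_to_kernel_blocks(np.array([5, 7]), 2))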