# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
|
|
|
|
|
|
|
import numpy as np
|
|
import torch
|
|
|
|
from vllm.distributed import get_dcp_group
|
|
from vllm.logger import init_logger
|
|
from vllm.v1.worker.block_table import BlockTable
|
|
|
|
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
|
|
|
# Module-level logger, created via vLLM's logger factory so log formatting
# stays consistent with the rest of the vLLM/vllm_mlu codebase.
logger = init_logger(__name__)
|
class BlockTable_MluHijack(BlockTable):
    """MLU replacement for ``BlockTable`` whose only behavioral change is
    allocating the ``slot_mapping`` buffer as ``torch.int32`` instead of
    ``torch.int64`` (see the marked hijack section below)."""

    def __init__(
        self,
        block_size: int,
        max_num_reqs: int,
        max_num_blocks_per_req: int,
        max_num_batched_tokens: int,
        pin_memory: bool,
        device: torch.device,
        kernel_block_size: int,
        dcp_kv_cache_interleave_size: int,
    ):
        """
        Args:
            block_size: Block size used for KV cache memory allocation.
            max_num_reqs: Maximum number of concurrent requests supported.
            max_num_blocks_per_req: Maximum number of blocks per request.
            max_num_batched_tokens: Maximum number of tokens in a batch.
            pin_memory: Whether to pin memory for faster GPU transfers.
            device: Target device for the block table.
            kernel_block_size: The block_size of underlying attention kernel.
                Will be the same as `block_size` if `block_size` is supported
                by the attention kernel.
            dcp_kv_cache_interleave_size: Interleave size used when the KV
                cache is sharded across decode-context-parallel (DCP) ranks.
                (Stored as-is; consumers are outside this module.)
        """
        self.max_num_reqs = max_num_reqs
        self.max_num_batched_tokens = max_num_batched_tokens
        self.pin_memory = pin_memory
        self.device = device

        if kernel_block_size == block_size:
            # Standard case: allocation and computation use same block size.
            # No block splitting needed, direct mapping.
            self.block_size = block_size
            self.blocks_per_kv_block = 1
            self.use_hybrid_blocks = False
        else:
            # Hybrid case: allocation block size differs from kernel block size.
            # Memory blocks are subdivided to match kernel requirements.
            # Example: 32-token memory blocks with 16-token kernel blocks
            # -> each memory block corresponds to 2 kernel blocks.
            if block_size % kernel_block_size != 0:
                raise ValueError(
                    f"kernel_block_size {kernel_block_size} must divide "
                    f"kv_manager_block_size {block_size} evenly"
                )

            self.block_size = kernel_block_size
            self.blocks_per_kv_block = block_size // kernel_block_size
            self.use_hybrid_blocks = True

        # Table capacity is expressed in kernel-sized blocks, so scale the
        # per-request block budget by the split factor.
        self.max_num_blocks_per_req = max_num_blocks_per_req * self.blocks_per_kv_block

        self.block_table = self._make_buffer(
            self.max_num_reqs, self.max_num_blocks_per_req, dtype=torch.int32
        )
        self.num_blocks_per_row = np.zeros(max_num_reqs, dtype=np.int32)

        '''
        =============================
        Modify by vllm_mlu
        =============================
        @brief: change slot_mapping dtype from int64 to int32
        '''
        self.slot_mapping = self._make_buffer(
            self.max_num_batched_tokens, dtype=torch.int32
        )
        '''
        ==================
        End of MLU Hijack
        ==================
        '''

        if self.use_hybrid_blocks:
            # Offsets [0, blocks_per_kv_block) used to expand each memory
            # block id into its consecutive kernel block ids (row vector for
            # broadcasting).
            self._kernel_block_arange = np.arange(0, self.blocks_per_kv_block).reshape(
                1, -1
            )
        else:
            self._kernel_block_arange = None

        try:
            self.dcp_world_size = get_dcp_group().world_size
            self.dcp_rank = get_dcp_group().rank_in_group
        except AssertionError:
            # DCP might not be initialized in testing
            self.dcp_world_size = 1
            self.dcp_rank = 0
        self.dcp_kv_cache_interleave_size = dcp_kv_cache_interleave_size
|
|
|
|
|
|
# Register the MLU override: presumably apply_hijack patches
# BlockTable.__init__ (the original is passed so it can be restored) with
# BlockTable_MluHijack.__init__ — exact semantics live in
# vllm_mlu.mlu_hijack_utils; confirm there.
MluHijackObject.apply_hijack(
    BlockTable,
    BlockTable.__init__,
    BlockTable_MluHijack.__init__
)