import numpy as np
import torch
import os

from vllm.distributed import get_dcp_group
from vllm.logger import init_logger
from vllm.utils import cdiv
from vllm_vacc.vllm.model_executor.models.vars import (
    BLOCK_GROUP_SIZE as env_blk_grp_size,
)

logger = init_logger(__name__)


class BlockTable:
    """Per-request block table for the paged KV cache.

    Holds, for up to ``max_num_reqs`` concurrent requests, the mapping from
    logical block index to physical KV-cache block id (``block_table``), the
    per-request valid-block counts (``num_blocks_per_row``), and the flat
    token -> KV-slot mapping for a batch (``slot_mapping``).

    NOTE(review): only ``__init__`` is visible in this chunk; the buffer
    pair returned by ``_make_buffer`` is created by a method defined
    elsewhere in this class — presumably a paired device/CPU tensor, confirm
    against its definition.
    """

    def __init__(
        self,
        block_size: int,
        max_num_reqs: int,
        max_num_blocks_per_req: int,
        max_num_batched_tokens: int,
        pin_memory: bool,
        device: torch.device,
    ):
        """Allocate the block-table and slot-mapping buffers.

        Args:
            block_size: Number of tokens per KV-cache block.
            max_num_reqs: Maximum number of concurrent requests (rows).
            max_num_blocks_per_req: Upper bound on blocks per request;
                rounded up here to the backend's group alignment.
            max_num_batched_tokens: Maximum tokens in one scheduled batch.
            pin_memory: Whether CPU-side buffers should use pinned memory.
            device: Device on which device-side buffers live.
        """
        self.block_size = block_size
        self.max_num_reqs = max_num_reqs

        # Round the per-request block capacity up to a multiple of
        # env_blk_grp_size // 16 so each block-table row is group-aligned
        # for the vacc backend. cdiv(a, b) * b == ceil(a / b) * b, which is
        # exactly the original hand-rolled (a + b - 1) // b * b.
        group = env_blk_grp_size // 16
        max_num_blocks_per_req = cdiv(max_num_blocks_per_req, group) * group
        self.max_num_blocks_per_req = max_num_blocks_per_req

        self.max_num_batched_tokens = max_num_batched_tokens
        self.pin_memory = pin_memory
        self.device = device

        # (max_num_reqs, max_num_blocks_per_req) int32 table of physical
        # block ids, allocated via the class's buffer helper.
        self.block_table = self._make_buffer(
            max_num_reqs, max_num_blocks_per_req, dtype=torch.int32
        )
        # How many entries of each row are currently valid.
        self.num_blocks_per_row = np.zeros(max_num_reqs, dtype=np.int32)
        # Flat per-token slot mapping for one batch of scheduled tokens.
        self.slot_mapping = self._make_buffer(
            self.max_num_batched_tokens, dtype=torch.int32
        )

        try:
            self.dcp_world_size = get_dcp_group().world_size
            self.dcp_rank = get_dcp_group().rank_in_group
        except AssertionError:
            # DCP might not be initialized in testing; fall back to a
            # single-rank world.
            self.dcp_world_size = 1
            self.dcp_rank = 0