import torch import torch_vacc class BinCountTensorPooler: """ sampler类中(bin-count tensor)的缓冲器。 核心功能:为apply_penalties 提供创建好的bin-count, 为避免重复重头计算。 注意: 如果系统中的req_id未能确保在max_count的数量下的唯一性,要考虑是否适用该pooler pooler = BinCountTensorPooler(15169, "cpu") list_1 = pooler.request_tensors(['1','3','aad']) bin_count_buffers = pooler.request_tensors(req_ids) """ def __init__( self, vocab_size: int, device: torch.device, max_cache_count: int = 20, ): # 缓存容器:用列表维护请求ID与对应张量的顺序(FIFO淘汰) self._cached_tensors: list[torch.Tensor] = [] self._cached_req_ids: list[str] = [] # 张量维度(+1 通常用于padding/特殊标记) self.vocab_size: int = vocab_size + 1 self.device: torch.device = device self.max_cache_count: int = max_cache_count def request_tensors(self, req_ids: list[str]) -> list[torch.Tensor]: """ 批量请求张量:已缓存的直接返回,未缓存的创建并加入缓存。 Args: req_ids: 待请求的请求ID列表 Returns: 与req_ids顺序对应的bin-count tensor """ if not req_ids: return [] # 空输入快速返回,避免无效循环 out_tensors = [] for req_id in req_ids: if req_id not in self._cached_req_ids: # 注册新的req, 增加cached-tensor self._add_new_cache(req_id) # 快速获取缓存索引(后续可优化为dict映射,提升性能) cache_idx = self._cached_req_ids.index(req_id) out_tensors.append(self._cached_tensors[cache_idx]) return out_tensors def _add_new_cache(self, req_id: str) -> None: """ 新增缓存项:创建张量并加入缓存,超出最大数量时按FIFO淘汰最早项。 私有方法:封装内部逻辑,避免外部直接调用 """ # FIFO淘汰:缓存满时移除头部(最早加入的项) if len(self._cached_req_ids) >= self.max_cache_count: # >= 更严谨(避免边界值问题) self._cached_tensors.pop(0) self._cached_req_ids.pop(0) # 创建二进制计数张量(int32类型,初始全0) bin_count_tensor = torch.zeros( size=(1, self.vocab_size), dtype=torch.int32, device=self.device, requires_grad=False, # 明确禁用梯度(计数张量无需反向传播) ) # 同步添加ID和张量(保证两个列表索引一致) self._cached_req_ids.append(req_id) self._cached_tensors.append(bin_count_tensor) def clear_cache(self) -> None: """清空所有缓存(可选扩展方法,方便外部手动清理)""" self._cached_tensors.clear() self._cached_req_ids.clear() def get_cache_status(self) -> dict: """获取缓存状态(可选扩展方法,方便监控)""" return { "current_cache_count": len(self._cached_req_ids), "max_cache_count": self.max_cache_count, "cached_req_ids": self._cached_req_ids.copy(), # 返回副本避免外部修改 "tensor_shape": (1, self.vocab_size), "tensor_dtype": torch.int32, "device": str(self.device), }