# NOTE(review): web-viewer scaffolding ("Files / 92 lines / 3.5 KiB / Raw
# Permalink ..." and a timestamp) was pasted above the code; replaced with
# this comment so the module is valid Python.
import torch
import torch_vacc
class BinCountTensorPooler:
    """FIFO cache of bin-count tensors for the sampler.

    Supplies pre-built bin-count tensors to ``apply_penalties`` so they are
    not re-created from scratch on every call.  When the cache is full, the
    oldest entry is evicted first (FIFO).

    NOTE(review): request ids must be unique among the ``max_cache_count``
    most recent requests; if the system cannot guarantee that, a reused id
    would silently receive another request's tensor — reconsider using this
    pooler in that case.

    Example:
        pooler = BinCountTensorPooler(15169, "cpu")
        tensors = pooler.request_tensors(['1', '3', 'aad'])
    """

    def __init__(
        self,
        vocab_size: int,
        device: torch.device,
        max_cache_count: int = 20,
    ):
        # Insertion-ordered dict (guaranteed since Python 3.7) gives O(1)
        # lookup by req_id while preserving FIFO eviction order, replacing
        # the previous pair of parallel lists with O(n) .index() lookups.
        self._cache: dict[str, torch.Tensor] = {}
        # +1 column — presumably a padding/special-token slot; TODO confirm.
        self.vocab_size: int = vocab_size + 1
        self.device: torch.device = device
        self.max_cache_count: int = max_cache_count

    def request_tensors(self, req_ids: list[str]) -> list[torch.Tensor]:
        """Return one cached bin-count tensor per request id.

        Ids seen before return their cached tensor; new ids get a fresh
        zeroed tensor which is also added to the cache.

        Args:
            req_ids: Request ids to look up; may be empty.

        Returns:
            Tensors in the same order as ``req_ids``.
        """
        out_tensors: list[torch.Tensor] = []
        for req_id in req_ids:
            if req_id not in self._cache:
                # Register the new request and allocate its tensor.
                self._add_new_cache(req_id)
            out_tensors.append(self._cache[req_id])
        return out_tensors

    def _add_new_cache(self, req_id: str) -> None:
        """Create and cache a zeroed tensor for ``req_id``.

        Evicts the oldest cached entry first when the cache is full (FIFO).
        Private: internal bookkeeping only.
        """
        if len(self._cache) >= self.max_cache_count:
            # next(iter(...)) yields the oldest (first-inserted) key.
            self._cache.pop(next(iter(self._cache)))
        self._cache[req_id] = torch.zeros(
            size=(1, self.vocab_size),
            dtype=torch.int32,
            device=self.device,
            # Count tensors never need gradients.
            requires_grad=False,
        )

    def clear_cache(self) -> None:
        """Drop every cached tensor (manual cleanup hook for callers)."""
        self._cache.clear()

    def get_cache_status(self) -> dict:
        """Return a snapshot of the cache state for monitoring."""
        return {
            "current_cache_count": len(self._cache),
            "max_cache_count": self.max_cache_count,
            # Copy so callers cannot mutate internal state.
            "cached_req_ids": list(self._cache),
            "tensor_shape": (1, self.vocab_size),
            "tensor_dtype": torch.int32,
            "device": str(self.device),
        }