92 lines
3.5 KiB
Python
92 lines
3.5 KiB
Python
import torch
|
||
import torch_vacc
|
||
|
||
class BinCountTensorPooler:
    """
    Pooler for bin-count tensors used by the sampler.

    Core purpose: hand ``apply_penalties`` a pre-allocated bin-count tensor
    per request id, so it is not rebuilt from scratch on every step.

    NOTE(review): if request ids are not guaranteed unique within a window of
    ``max_cache_count`` requests, reconsider whether this pooler is suitable —
    a colliding id would silently receive another request's counts.

    Example:
        pooler = BinCountTensorPooler(15169, "cpu")
        list_1 = pooler.request_tensors(['1', '3', 'aad'])
        bin_count_buffers = pooler.request_tensors(req_ids)
    """

    def __init__(
        self,
        vocab_size: int,
        device: torch.device,
        max_cache_count: int = 20,
    ):
        # Insertion-ordered mapping req_id -> tensor. A single dict replaces
        # the two parallel lists: O(1) lookup instead of two O(n) scans per
        # request, and no index-synchronization invariant to maintain.
        # Python dicts preserve insertion order, so FIFO eviction still works.
        self._cache: dict[str, torch.Tensor] = {}

        # +1 row beyond vocab_size — presumably a padding/special-token slot;
        # confirm against apply_penalties' expectations.
        self.vocab_size: int = vocab_size + 1
        self.device: torch.device = device
        self.max_cache_count: int = max_cache_count

    def request_tensors(self, req_ids: list[str]) -> list[torch.Tensor]:
        """
        Batch-request tensors: cached ids return their existing tensor,
        unknown ids get a freshly created (zeroed) one added to the cache.

        Args:
            req_ids: Request ids to look up, in output order.

        Returns:
            Bin-count tensors, one per entry of ``req_ids`` (same order).
            Cached tensors are returned as-is — they are NOT re-zeroed.
        """
        if not req_ids:
            return []  # fast path for empty input

        out_tensors = []
        for req_id in req_ids:
            tensor = self._cache.get(req_id)
            if tensor is None:
                # Unknown request: create and register a fresh tensor.
                tensor = self._add_new_cache(req_id)
            out_tensors.append(tensor)

        return out_tensors

    def _add_new_cache(self, req_id: str) -> torch.Tensor:
        """
        Create a zeroed bin-count tensor for *req_id* and cache it,
        evicting the oldest entry (FIFO) when the cache is full.

        Returns:
            The newly created tensor (also stored in the cache).
        """
        # FIFO eviction: drop the oldest entry once the cap is reached.
        # ``>=`` guards against the cap being lowered between calls.
        if len(self._cache) >= self.max_cache_count:
            self._cache.pop(next(iter(self._cache)))

        # int32 zero-initialized count buffer; counting needs no autograd.
        bin_count_tensor = torch.zeros(
            size=(1, self.vocab_size),
            dtype=torch.int32,
            device=self.device,
            requires_grad=False,
        )
        self._cache[req_id] = bin_count_tensor
        return bin_count_tensor

    def clear_cache(self) -> None:
        """Drop every cached tensor (manual cleanup hook for callers)."""
        self._cache.clear()

    def get_cache_status(self) -> dict:
        """Return a snapshot of cache state for monitoring/debugging."""
        return {
            "current_cache_count": len(self._cache),
            "max_cache_count": self.max_cache_count,
            # Copy so callers cannot mutate internal ordering.
            "cached_req_ids": list(self._cache.keys()),
            "tensor_shape": (1, self.vocab_size),
            "tensor_dtype": torch.int32,
            "device": str(self.device),
        }
|