Files
2026-04-02 04:55:00 +00:00

92 lines
3.5 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import torch
import torch_vacc
class BinCountTensorPooler:
    """Pool of reusable bin-count tensors for the sampler.

    Supplies pre-built bin-count tensors to ``apply_penalties`` so they are
    not rebuilt from scratch on every call.  Entries are evicted in FIFO
    order once ``max_cache_count`` is reached.

    NOTE: if request IDs are not guaranteed unique within
    ``max_cache_count`` live requests, reconsider whether this pooler is
    appropriate — colliding IDs would share a tensor.

    Example:
        pooler = BinCountTensorPooler(15169, "cpu")
        list_1 = pooler.request_tensors(['1', '3', 'aad'])
        bin_count_buffers = pooler.request_tensors(req_ids)
    """

    def __init__(
        self,
        vocab_size: int,
        device: torch.device,
        max_cache_count: int = 20,
    ):
        # Single insertion-ordered mapping req_id -> tensor.  Python 3.7+
        # dicts preserve insertion order, which gives O(1) lookup AND
        # natural FIFO eviction — replaces the previous pair of parallel
        # lists with their O(n) `in`/`.index()` scans.
        self._cache: dict[str, torch.Tensor] = {}
        # +1 slot, presumably for padding / a special token — TODO confirm.
        self.vocab_size: int = vocab_size + 1
        self.device: torch.device = device
        self.max_cache_count: int = max_cache_count

    def request_tensors(self, req_ids: list[str]) -> list[torch.Tensor]:
        """Return one bin-count tensor per request ID, order preserved.

        Already-cached IDs reuse their existing tensor; new IDs get a
        fresh zero-initialized tensor that is also added to the cache.

        Args:
            req_ids: request IDs to fetch tensors for.

        Returns:
            Tensors aligned index-for-index with ``req_ids``.
        """
        if not req_ids:
            return []  # fast path for empty input
        out_tensors: list[torch.Tensor] = []
        for req_id in req_ids:
            tensor = self._cache.get(req_id)
            if tensor is None:
                # Unknown request: allocate and register a new buffer.
                tensor = self._add_new_cache(req_id)
            out_tensors.append(tensor)
        return out_tensors

    def _add_new_cache(self, req_id: str) -> torch.Tensor:
        """Create, cache, and return a zeroed bin-count tensor for *req_id*.

        Evicts the oldest entry (FIFO) when the cache is full.  Private:
        callers go through :meth:`request_tensors`.
        """
        if len(self._cache) >= self.max_cache_count:
            # Dict iteration order == insertion order, so the first key
            # is the oldest entry (FIFO), dropped in O(1).
            oldest = next(iter(self._cache))
            del self._cache[oldest]
        bin_count_tensor = torch.zeros(
            size=(1, self.vocab_size),
            dtype=torch.int32,
            device=self.device,
            requires_grad=False,  # count buffers never need gradients
        )
        self._cache[req_id] = bin_count_tensor
        return bin_count_tensor

    def clear_cache(self) -> None:
        """Drop all cached tensors (convenience for external cleanup)."""
        self._cache.clear()

    def get_cache_status(self) -> dict:
        """Return a snapshot of the cache state for monitoring."""
        return {
            "current_cache_count": len(self._cache),
            "max_cache_count": self.max_cache_count,
            # Copy so callers cannot mutate internal state.
            "cached_req_ids": list(self._cache),
            "tensor_shape": (1, self.vocab_size),
            "tensor_dtype": torch.int32,
            "device": str(self.device),
        }