30 lines
1.2 KiB
Python
30 lines
1.2 KiB
Python
import torch
|
|
from transformers.cache_utils import DynamicCache
|
|
from typing import Optional, List, Tuple, Dict, Any
|
|
|
|
class DiffusionDynamicCache(DynamicCache):
    """`DynamicCache` subclass with helpers that act on every layer at once.

    All helpers read and write the per-layer ``self.key_cache`` and
    ``self.value_cache`` lists directly.
    """

    def __init__(self, num_hidden_layers: Optional[int] = None):
        """Create the cache.

        Args:
            num_hidden_layers: Forwarded unchanged to ``DynamicCache.__init__``.
        """
        super().__init__(num_hidden_layers)

    def full_update(
        self,
        new_kv: Tuple,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Append new key/value states to all layers in a single call.

        Layers that have no cached tensor yet (the list is shorter than the
        layer index, or the slot holds a non-tensor placeholder) are
        initialised with the new states instead of being concatenated.

        Args:
            new_kv: Per-layer ``(key, value)`` pairs, ordered by layer index.
            cache_kwargs: Accepted but not used by this method.
        """
        for layer_idx, (key, val) in enumerate(new_kv):
            if layer_idx >= len(self.key_cache):
                # First time this layer is seen (e.g. empty cache): nothing to
                # concatenate with, so the new states become the cache entry.
                self.key_cache.append(key)
                self.value_cache.append(val)
            elif not torch.is_tensor(self.key_cache[layer_idx]):
                # Slot exists but holds a non-tensor placeholder; torch.cat
                # would raise on it, so overwrite the slot instead.
                self.key_cache[layer_idx] = key
                self.value_cache[layer_idx] = val
            else:
                # Grow along the second-to-last axis.
                # NOTE(review): presumably the sequence axis of a
                # (batch, heads, seq, head_dim) tensor -- confirm with callers.
                self.key_cache[layer_idx] = torch.cat(
                    [self.key_cache[layer_idx], key], dim=-2
                )
                self.value_cache[layer_idx] = torch.cat(
                    [self.value_cache[layer_idx], val], dim=-2
                )

    def select_partial(
        self,
        indices: torch.Tensor,
    ) -> None:
        """Keep only the positions ``indices`` along axis 2 of every layer.

        Each cached key/value tensor is indexed as ``[:, :, indices, :]``, so
        the cached tensors must be 4-dimensional.

        Args:
            indices: Index used to select entries along axis 2.
        """
        for layer_idx in range(len(self.key_cache)):
            self.key_cache[layer_idx] = self.key_cache[layer_idx][:, :, indices, :]
            self.value_cache[layer_idx] = self.value_cache[layer_idx][:, :, indices, :]

    def batch_select_minibatch(self, indices: torch.Tensor) -> None:
        """Keep only the first ``indices`` entries along the batch (first) axis.

        ``indices`` is used as a slice stop (``[:indices, ...]``), not as a
        list of rows to gather, so it has to be convertible to a single
        integer (an ``int`` or a one-element integer tensor).

        Args:
            indices: Number of leading batch entries to keep in every layer.
        """
        for layer_idx in range(len(self)):
            self.key_cache[layer_idx] = self.key_cache[layer_idx][:indices, ...]
            self.value_cache[layer_idx] = self.value_cache[layer_idx][:indices, ...]