初始化项目,由ModelHub XC社区提供模型
Model: GSAI-ML/ReFusion Source: Original Platform
This commit is contained in:
30
diffusion_cache_utils.py
Normal file
30
diffusion_cache_utils.py
Normal file
@@ -0,0 +1,30 @@
|
||||
import torch
|
||||
from transformers.cache_utils import DynamicCache
|
||||
from typing import Optional, List, Tuple, Dict, Any
|
||||
|
||||
class DiffusionDynamicCache(DynamicCache):
|
||||
def __init__(self, num_hidden_layers: Optional[int] = None):
|
||||
super().__init__(num_hidden_layers)
|
||||
|
||||
def full_update(
|
||||
self,
|
||||
new_kv: Tuple,
|
||||
cache_kwargs: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
for i, (key, val) in enumerate(new_kv):
|
||||
self.key_cache[i] = torch.cat([self.key_cache[i], key], dim=-2)
|
||||
self.value_cache[i] = torch.cat([self.value_cache[i], val], dim=-2)
|
||||
|
||||
def select_partial(
|
||||
self,
|
||||
indices: torch.Tensor,
|
||||
):
|
||||
for i in range(len(self.key_cache)):
|
||||
self.key_cache[i] = self.key_cache[i][:, :, indices, :]
|
||||
self.value_cache[i] = self.value_cache[i][:, :, indices, :]
|
||||
|
||||
def batch_select_minibatch(self, indices: torch.Tensor):
|
||||
"""Only keep the `indices` in the batch dimension of the cache. Used in contrastive search."""
|
||||
for layer_idx in range(len(self)):
|
||||
self.key_cache[layer_idx] = self.key_cache[layer_idx][:indices, ...]
|
||||
self.value_cache[layer_idx] = self.value_cache[layer_idx][:indices, ...]
|
||||
Reference in New Issue
Block a user