初始化项目,由ModelHub XC社区提供模型

Model: transformers-community/sink_cache
Source: Original Platform
This commit is contained in:
ModelHub XC
2026-05-20 00:27:08 +08:00
commit 23b382f097
10 changed files with 151789 additions and 0 deletions

36
.gitattributes vendored Normal file
View File

@@ -0,0 +1,36 @@
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tar filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
*.json filter=lfs diff=lfs merge=lfs -text

116
README.md Normal file
View File

@@ -0,0 +1,116 @@
---
library_name: transformers
tags:
- custom_generate
---
## Description
Implementation of the KV cache introduced in the [Attention Sinks paper](https://huggingface.co/papers/2309.17453).
It allows the model to generate beyond the length of its context window, without losing fluency in the conversation.
This is done by always keeping the first few tokens ("sink tokens") in the KV cache, as models often pay a large
amount of attention to them. As it discards past non-sink tokens, the model will lose the ability to generate tokens
that depend on the context that was discarded. It's also a solution to contain the memory footprint of the KV cache.
This implementation matches the `SinkCache` class present in `transformers<4.53.0`.
![Sink Cache diagram from the original paper](https://arxiv.org/html/2309.17453v4/x1.png)
<!-- TODO (joao): add `transformers chat` example -->
## Base model
- [Qwen/Qwen3-0.6B](https://huggingface.co/Qwen/Qwen3-0.6B)
## Model compatibility
- Decoder-only transformers models
## Additional Arguments
- `window_length` (`int`, *optional*, defaults to 256): The length of the context window.
- `num_sink_tokens` (`int`, *optional*, defaults to 4): The number of sink tokens. See the original paper for more information.
## Output Type changes
- When `return_dict_in_generate=True`, `output.past_key_values` will be a `SinkCache` instance. `SinkCache` is defined
in `generate.py`, in this repository.
## Example usage
We can use the custom generation method in this repository like the the base `generate` from `transformers`:
```py
# requires `transformers>=4.52.0`
from transformers import AutoModelForCausalLM, AutoTokenizer
# Preparing model, tokenizer, and model inputs
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B", device_map="auto")
messages = [{"role": "user", "content": "Tell me a story about a cat."}]
text = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
enable_thinking=False
)
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
# Using sink cache
gen_out = model.generate(
# usual `generate` arguments
**model_inputs,
do_sample=False,
max_new_tokens=100,
return_dict_in_generate=True,
# sink cache arguments (default `window_length=256`)
custom_generate="transformers-community/sink_cache",
trust_remote_code=True,
)
print(tokenizer.batch_decode(gen_out.sequences, skip_special_tokens=True))
assert "sinkcache" in str(type(gen_out.past_key_values)).lower()
# ['user\nTell me a story about a cat.\nassistant\n<think>\n\n</think>\n\nOnce upon a time, in a cozy village nestled
# between rolling hills and a sparkling lake, there lived a cat named Luna. Luna was small and fluffy, with a curious
# eyes that sparkled with wonder. She had a soft, warm coat that shimmered like the morning sun, and her tail was
# always wagging in playful motions.\n\nOne day, while exploring the village, Luna noticed a curious sight: a young
# boy playing with a ball on the lake. She followed him closely, her heart racing']
```
Continuing the example above, we can confirm some properties of the `SinkCache`
```py
# `max_new_tokens` < `window_length` in the example above -> matches output with the default cache
gen_out = model.generate(
**model_inputs,
do_sample=False,
max_new_tokens=100,
return_dict_in_generate=True,
)
print(tokenizer.batch_decode(gen_out.sequences, skip_special_tokens=True))
assert "dynamiccache" in str(type(gen_out.past_key_values)).lower()
# ['user\nTell me a story about a cat.\nassistant\n<think>\n\n</think>\n\nOnce upon a time, in a cozy village nestled
# between rolling hills and a sparkling lake, there lived a cat named Luna. Luna was small and fluffy, with a curious
# eyes that sparkled with wonder. She had a soft, warm coat that shimmered like the morning sun, and her tail was
# always wagging in playful motions.\n\nOne day, while exploring the village, Luna noticed a curious sight: a young
# boy playing with a ball on the lake. She followed him closely, her heart racing']
# if we set a smaller `window_length`, the story is less coherent after that point, but the used cache is also
# significantly smaller
gen_out = model.generate(
# usual `generate` arguments
**model_inputs,
do_sample=False,
max_new_tokens=100,
return_dict_in_generate=True,
# sink cache arguments
custom_generate="transformers-community/sink_cache",
trust_remote_code=True,
window_length=50,
)
print(tokenizer.batch_decode(gen_out.sequences, skip_special_tokens=True))
# ["user\nTell me a story about a cat.\nassistant\n<think>\n\n</think>\n\nOnce upon a time, in a cozy village nestled
# between rolling hills and a sparkling lake, there lived a cat named Luna. Luna was small and fluffy, with a curious
# heart. She loved exploring the village and playing with her friends.\n\nOne day, Luna noticed something unusual.
# She looked around and saw a shadow moving in the dark. She ran quickly, but she couldn't see the shadow. She
# thought maybe it was a ghost or something else.\n\nAs she was running, she heard a voice."]
```

3
config.json Normal file
View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:660db3b73d788119c04535e48cf9be5f55bc3100841a718637ae695b442f27dd
size 726

231
custom_generate/generate.py Normal file
View File

@@ -0,0 +1,231 @@
import torch
from typing import Any, Dict, List, Optional, Tuple
from transformers import Cache, GenerationConfig
UNSUPPORTED_GENERATION_ARGS = [
"cache_implementation", # cache-related arguments, here we always use SinkCache
"cache_config",
"return_legacy_cache",
"num_beams", # beam search (and cousin techniques) are not supported
"compile_config", # SinkCache doesn't support torch.compile
"assistant_model", # it also doesn't support speculative decoding
]
class SinkCache(Cache):
"""
A cache that as described in the [Attention Sinks paper](https://arxiv.org/abs/2309.17453). It allows the model to
generate beyond the length of its context window, without losing fluency in the conversation. As it discards past
tokens, the model will lose the ability to generate tokens that depend on the context that was discarded.
It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
`[batch_size, num_heads, seq_len, head_dim]`.
This class was copied from transformers 4.52.0, with minor modifications.
Parameters:
window_length (`int`):
The length of the context window.
num_sink_tokens (`int`):
The number of sink tokens. See the original paper for more information.
"""
def __init__(self, window_length: int, num_sink_tokens: int) -> None:
super().__init__()
self.key_cache: List[torch.Tensor] = []
self.value_cache: List[torch.Tensor] = []
self.window_length = window_length
self.num_sink_tokens = num_sink_tokens
self.cos_sin_rerotation_cache = {}
self._cos_cache = None
self._sin_cache = None
@staticmethod
def _rotate_half(x):
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def _apply_key_rotary_pos_emb(
self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
) -> torch.Tensor:
rotated_key_states = (key_states * cos) + (self._rotate_half(key_states) * sin)
return rotated_key_states
def _get_rerotation_cos_sin(
self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
if key_states.shape[-2] not in self.cos_sin_rerotation_cache:
# Upcast to float32 temporarily for better accuracy
cos = cos.to(torch.float32)
sin = sin.to(torch.float32)
# Compute the cos and sin required for back- and forward-rotating to one position earlier in the sequence
original_cos = cos[self.num_sink_tokens + key_states.shape[-2] :]
shifted_cos = cos[self.num_sink_tokens : -key_states.shape[-2]]
original_sin = sin[self.num_sink_tokens + key_states.shape[-2] :]
shifted_sin = sin[self.num_sink_tokens : -key_states.shape[-2]]
rerotation_cos = original_cos * shifted_cos + original_sin * shifted_sin
rerotation_sin = -original_sin * shifted_cos + original_cos * shifted_sin
self.cos_sin_rerotation_cache[key_states.shape[-2]] = (
rerotation_cos.to(key_states.dtype).unsqueeze(0),
rerotation_sin.to(key_states.dtype).unsqueeze(0),
)
return self.cos_sin_rerotation_cache[key_states.shape[-2]]
def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
"""Returns the sequence length of the cached states. A layer index can be optionally passed."""
if len(self.key_cache) <= layer_idx:
return 0
return self.key_cache[layer_idx].shape[-2]
def get_max_cache_shape(self) -> Optional[int]:
"""Returns the maximum sequence length of the cache object, in case of SinkCache it is the window length."""
return self.window_length
def update(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
layer_idx: int,
cache_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
Parameters:
key_states (`torch.Tensor`):
The new key states to cache.
value_states (`torch.Tensor`):
The new value states to cache.
layer_idx (`int`):
The index of the layer to cache the states for.
cache_kwargs (`Dict[str, Any]`, `optional`):
Additional arguments for the cache subclass. The following arguments can be used in `SinkCache`: `sin`,
`cos` and `partial_rotation_size`. These arguments are used with models using RoPE, to recompute the
rotation as the tokens are shifted.
Return:
A tuple containing the updated key and value states.
"""
# Optional kwargs for `SinkCache` -- needed on models using RoPE. `partial_rotation_size` is used on models
# with partially rotated position embeddings, like Phi or Persimmon.
if cache_kwargs is None:
cache_kwargs = {}
sin = cache_kwargs.get("sin")
cos = cache_kwargs.get("cos")
partial_rotation_size = cache_kwargs.get("partial_rotation_size")
using_rope = cos is not None and sin is not None
# Update the sin/cos cache, which holds sin/cos values for all possible positions
if using_rope and layer_idx == 0:
# BC: some models still pass `sin`/`cos` with 2 dims. In those models, they are the full sin/cos. Remove
# after all RoPE models have a llama-like cache utilization.
if cos.dim() == 2:
self._cos_cache = cos
self._sin_cache = sin
else:
if self._cos_cache is None:
self._cos_cache = cos[0, ...]
self._sin_cache = sin[0, ...]
elif self._cos_cache.shape[0] < self.window_length:
self._cos_cache = torch.cat([self._cos_cache, cos[0, ...]], dim=0)
self._sin_cache = torch.cat([self._sin_cache, sin[0, ...]], dim=0)
# [bsz, num_heads, seq_len, head_dim]
if len(self.key_cache) <= layer_idx:
# Empty cache
self.key_cache.append(key_states)
self.value_cache.append(value_states)
elif key_states.shape[-2] + self.get_seq_length(layer_idx) < self.window_length:
# Growing cache
self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
else:
# Shifting cache
keys_to_keep = self.key_cache[layer_idx][
:, :, -self.window_length + self.num_sink_tokens + key_states.shape[-2] :
]
# On RoPE models, we need to recompute the Key rotation as the tokens are shifted
if using_rope:
rerotation_cos, rerotation_sin = self._get_rerotation_cos_sin(
key_states, self._cos_cache[: self.window_length], self._sin_cache[: self.window_length]
)
if partial_rotation_size is not None:
keys_to_keep, keys_pass = (
keys_to_keep[..., :partial_rotation_size],
keys_to_keep[..., partial_rotation_size:],
)
keys_to_keep = self._apply_key_rotary_pos_emb(keys_to_keep, rerotation_cos, rerotation_sin)
if partial_rotation_size is not None:
keys_to_keep = torch.cat((keys_to_keep, keys_pass), dim=-1)
# Concatenate sink tokens, shifted & rotated tokens (if needed), and new tokens
sink_keys = self.key_cache[layer_idx][:, :, : self.num_sink_tokens]
self.key_cache[layer_idx] = torch.cat([sink_keys, keys_to_keep, key_states], dim=-2)
sink_values = self.value_cache[layer_idx][:, :, : self.num_sink_tokens]
values_to_keep = self.value_cache[layer_idx][
:, :, -self.window_length + self.num_sink_tokens + value_states.shape[-2] :
]
self.value_cache[layer_idx] = torch.cat([sink_values, values_to_keep, value_states], dim=-2)
return self.key_cache[layer_idx], self.value_cache[layer_idx]
def generate(model, window_length=256, num_sink_tokens=4, **kwargs):
"""Custom generate function for SinkCache.
Args:
model (`PreTrainedModel`):
The model to generate from.
window_length (`int`, *optional*, defaults to 256):
The length of the context window.
num_sink_tokens (`int`, *optional*, defaults to 4):
The number of sink tokens. See the original paper for more information.
"""
# 1. General sanity checks
# 1.a. A few arguments are not allowed, especially arguments that control caches.
generation_config = kwargs.get("generation_config")
default_global_generation_config = GenerationConfig()
default_model_generation_config = model.generation_config
for arg in UNSUPPORTED_GENERATION_ARGS:
has_custom_gen_config_arg = (
generation_config is not None
# = and not (match global default or match model-specific default)
and not (
getattr(default_model_generation_config, arg) == getattr(generation_config, arg)
or getattr(default_global_generation_config, arg) == getattr(generation_config, arg)
)
)
kwargs_has_arg = arg in kwargs and kwargs[arg] is not None
if kwargs_has_arg or has_custom_gen_config_arg:
raise ValueError(
f"`{arg}` is set, but it's not supported in this custom generate function. List of "
f"unsupported arguments: {UNSUPPORTED_GENERATION_ARGS}"
)
# 1.b. The model must be decoder-only
if model.config.is_encoder_decoder:
raise ValueError("This custom generate function only works with decoder-only models")
# 1.c. compatibility with transformers 4.52: we must pop `custom_generate` from kwargs, otherwise it will result
# in an infinite loop when we call `model.generate`. This is solved in transformers 4.53.
kwargs.pop("custom_generate", None)
# 2. Generate with SinkCache
# 2.a. prepare the cache, if it was not passed.
past_key_values = kwargs.pop("past_key_values", None)
if past_key_values is None:
past_key_values = SinkCache(window_length=window_length, num_sink_tokens=num_sink_tokens)
elif not isinstance(past_key_values, SinkCache):
raise ValueError(f"`past_key_values` must be a `SinkCache` instance, got a {type(past_key_values)} instance")
# 2.b. generate with the cache
generation_outputs = model.generate(**kwargs, past_key_values=past_key_values, use_cache=True)
return generation_outputs

3
generation_config.json Normal file
View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2325da0f15bb848e018c5ae071b7943332e9f871d6b60e2ed22ca97d4cb993d2
size 239

151388
merges.txt Normal file

File diff suppressed because it is too large Load Diff

3
model.safetensors Normal file
View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f47f71177f32bcd101b7573ec9171e6a57f4f4d31148d38e382306f42996874b
size 1503300328

BIN
tokenizer.json (Stored with Git LFS) Normal file

Binary file not shown.

3
tokenizer_config.json Normal file
View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d5d09f07b48c3086c508b30d1c9114bd1189145b74e982a265350c923acd8101
size 9732

BIN
vocab.json (Stored with Git LFS) Normal file

Binary file not shown.