初始化项目,由ModelHub XC社区提供模型
Model: transformers-community/sink_cache Source: Original Platform
This commit is contained in:
36
.gitattributes
vendored
Normal file
36
.gitattributes
vendored
Normal file
@@ -0,0 +1,36 @@
|
||||
*.7z filter=lfs diff=lfs merge=lfs -text
|
||||
*.arrow filter=lfs diff=lfs merge=lfs -text
|
||||
*.bin filter=lfs diff=lfs merge=lfs -text
|
||||
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
||||
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
||||
*.ftz filter=lfs diff=lfs merge=lfs -text
|
||||
*.gz filter=lfs diff=lfs merge=lfs -text
|
||||
*.h5 filter=lfs diff=lfs merge=lfs -text
|
||||
*.joblib filter=lfs diff=lfs merge=lfs -text
|
||||
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
||||
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
||||
*.model filter=lfs diff=lfs merge=lfs -text
|
||||
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
||||
*.npy filter=lfs diff=lfs merge=lfs -text
|
||||
*.npz filter=lfs diff=lfs merge=lfs -text
|
||||
*.onnx filter=lfs diff=lfs merge=lfs -text
|
||||
*.ot filter=lfs diff=lfs merge=lfs -text
|
||||
*.parquet filter=lfs diff=lfs merge=lfs -text
|
||||
*.pb filter=lfs diff=lfs merge=lfs -text
|
||||
*.pickle filter=lfs diff=lfs merge=lfs -text
|
||||
*.pkl filter=lfs diff=lfs merge=lfs -text
|
||||
*.pt filter=lfs diff=lfs merge=lfs -text
|
||||
*.pth filter=lfs diff=lfs merge=lfs -text
|
||||
*.rar filter=lfs diff=lfs merge=lfs -text
|
||||
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
||||
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
||||
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
||||
*.tar filter=lfs diff=lfs merge=lfs -text
|
||||
*.tflite filter=lfs diff=lfs merge=lfs -text
|
||||
*.tgz filter=lfs diff=lfs merge=lfs -text
|
||||
*.wasm filter=lfs diff=lfs merge=lfs -text
|
||||
*.xz filter=lfs diff=lfs merge=lfs -text
|
||||
*.zip filter=lfs diff=lfs merge=lfs -text
|
||||
*.zst filter=lfs diff=lfs merge=lfs -text
|
||||
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
||||
*.json filter=lfs diff=lfs merge=lfs -text
|
||||
116
README.md
Normal file
116
README.md
Normal file
@@ -0,0 +1,116 @@
|
||||
---
|
||||
library_name: transformers
|
||||
tags:
|
||||
- custom_generate
|
||||
---
|
||||
|
||||
## Description
|
||||
Implementation of the KV cache introduced in the [Attention Sinks paper](https://huggingface.co/papers/2309.17453).
|
||||
It allows the model to generate beyond the length of its context window, without losing fluency in the conversation.
|
||||
This is done by always keeping the first few tokens ("sink tokens") in the KV cache, as models often pay a large
|
||||
amount of attention to them. As it discards past non-sink tokens, the model will lose the ability to generate tokens
|
||||
that depend on the context that was discarded. It's also a solution to contain the memory footprint of the KV cache.
|
||||
|
||||
This implementation matches the `SinkCache` class present in `transformers<4.53.0`.
|
||||
|
||||

|
||||
|
||||
<!-- TODO (joao): add `transformers chat` example -->
|
||||
|
||||
|
||||
## Base model
|
||||
- [Qwen/Qwen3-0.6B](https://huggingface.co/Qwen/Qwen3-0.6B)
|
||||
|
||||
|
||||
## Model compatibility
|
||||
- Decoder-only transformers models
|
||||
|
||||
|
||||
## Additional Arguments
|
||||
- `window_length` (`int`, *optional*, defaults to 256): The length of the context window.
|
||||
- `num_sink_tokens` (`int`, *optional*, defaults to 4): The number of sink tokens. See the original paper for more information.
|
||||
|
||||
|
||||
## Output Type changes
|
||||
- When `return_dict_in_generate=True`, `output.past_key_values` will be a `SinkCache` instance. `SinkCache` is defined
|
||||
in `generate.py`, in this repository.
|
||||
|
||||
|
||||
## Example usage
|
||||
|
||||
We can use the custom generation method in this repository like the the base `generate` from `transformers`:
|
||||
|
||||
```py
|
||||
# requires `transformers>=4.52.0`
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
# Preparing model, tokenizer, and model inputs
|
||||
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
|
||||
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B", device_map="auto")
|
||||
messages = [{"role": "user", "content": "Tell me a story about a cat."}]
|
||||
text = tokenizer.apply_chat_template(
|
||||
messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
enable_thinking=False
|
||||
)
|
||||
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
|
||||
|
||||
# Using sink cache
|
||||
gen_out = model.generate(
|
||||
# usual `generate` arguments
|
||||
**model_inputs,
|
||||
do_sample=False,
|
||||
max_new_tokens=100,
|
||||
return_dict_in_generate=True,
|
||||
# sink cache arguments (default `window_length=256`)
|
||||
custom_generate="transformers-community/sink_cache",
|
||||
trust_remote_code=True,
|
||||
)
|
||||
print(tokenizer.batch_decode(gen_out.sequences, skip_special_tokens=True))
|
||||
assert "sinkcache" in str(type(gen_out.past_key_values)).lower()
|
||||
# ['user\nTell me a story about a cat.\nassistant\n<think>\n\n</think>\n\nOnce upon a time, in a cozy village nestled
|
||||
# between rolling hills and a sparkling lake, there lived a cat named Luna. Luna was small and fluffy, with a curious
|
||||
# eyes that sparkled with wonder. She had a soft, warm coat that shimmered like the morning sun, and her tail was
|
||||
# always wagging in playful motions.\n\nOne day, while exploring the village, Luna noticed a curious sight: a young
|
||||
# boy playing with a ball on the lake. She followed him closely, her heart racing']
|
||||
```
|
||||
|
||||
Continuing the example above, we can confirm some properties of the `SinkCache`
|
||||
|
||||
```py
|
||||
# `max_new_tokens` < `window_length` in the example above -> matches output with the default cache
|
||||
gen_out = model.generate(
|
||||
**model_inputs,
|
||||
do_sample=False,
|
||||
max_new_tokens=100,
|
||||
return_dict_in_generate=True,
|
||||
)
|
||||
print(tokenizer.batch_decode(gen_out.sequences, skip_special_tokens=True))
|
||||
assert "dynamiccache" in str(type(gen_out.past_key_values)).lower()
|
||||
# ['user\nTell me a story about a cat.\nassistant\n<think>\n\n</think>\n\nOnce upon a time, in a cozy village nestled
|
||||
# between rolling hills and a sparkling lake, there lived a cat named Luna. Luna was small and fluffy, with a curious
|
||||
# eyes that sparkled with wonder. She had a soft, warm coat that shimmered like the morning sun, and her tail was
|
||||
# always wagging in playful motions.\n\nOne day, while exploring the village, Luna noticed a curious sight: a young
|
||||
# boy playing with a ball on the lake. She followed him closely, her heart racing']
|
||||
|
||||
# if we set a smaller `window_length`, the story is less coherent after that point, but the used cache is also
|
||||
# significantly smaller
|
||||
gen_out = model.generate(
|
||||
# usual `generate` arguments
|
||||
**model_inputs,
|
||||
do_sample=False,
|
||||
max_new_tokens=100,
|
||||
return_dict_in_generate=True,
|
||||
# sink cache arguments
|
||||
custom_generate="transformers-community/sink_cache",
|
||||
trust_remote_code=True,
|
||||
window_length=50,
|
||||
)
|
||||
print(tokenizer.batch_decode(gen_out.sequences, skip_special_tokens=True))
|
||||
# ["user\nTell me a story about a cat.\nassistant\n<think>\n\n</think>\n\nOnce upon a time, in a cozy village nestled
|
||||
# between rolling hills and a sparkling lake, there lived a cat named Luna. Luna was small and fluffy, with a curious
|
||||
# heart. She loved exploring the village and playing with her friends.\n\nOne day, Luna noticed something unusual.
|
||||
# She looked around and saw a shadow moving in the dark. She ran quickly, but she couldn't see the shadow. She
|
||||
# thought maybe it was a ghost or something else.\n\nAs she was running, she heard a voice."]
|
||||
```
|
||||
3
config.json
Normal file
3
config.json
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:660db3b73d788119c04535e48cf9be5f55bc3100841a718637ae695b442f27dd
|
||||
size 726
|
||||
231
custom_generate/generate.py
Normal file
231
custom_generate/generate.py
Normal file
@@ -0,0 +1,231 @@
|
||||
import torch
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from transformers import Cache, GenerationConfig
|
||||
|
||||
|
||||
UNSUPPORTED_GENERATION_ARGS = [
|
||||
"cache_implementation", # cache-related arguments, here we always use SinkCache
|
||||
"cache_config",
|
||||
"return_legacy_cache",
|
||||
"num_beams", # beam search (and cousin techniques) are not supported
|
||||
"compile_config", # SinkCache doesn't support torch.compile
|
||||
"assistant_model", # it also doesn't support speculative decoding
|
||||
]
|
||||
|
||||
class SinkCache(Cache):
|
||||
"""
|
||||
A cache that as described in the [Attention Sinks paper](https://arxiv.org/abs/2309.17453). It allows the model to
|
||||
generate beyond the length of its context window, without losing fluency in the conversation. As it discards past
|
||||
tokens, the model will lose the ability to generate tokens that depend on the context that was discarded.
|
||||
|
||||
It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
|
||||
`[batch_size, num_heads, seq_len, head_dim]`.
|
||||
|
||||
This class was copied from transformers 4.52.0, with minor modifications.
|
||||
|
||||
Parameters:
|
||||
window_length (`int`):
|
||||
The length of the context window.
|
||||
num_sink_tokens (`int`):
|
||||
The number of sink tokens. See the original paper for more information.
|
||||
"""
|
||||
|
||||
def __init__(self, window_length: int, num_sink_tokens: int) -> None:
|
||||
super().__init__()
|
||||
self.key_cache: List[torch.Tensor] = []
|
||||
self.value_cache: List[torch.Tensor] = []
|
||||
self.window_length = window_length
|
||||
self.num_sink_tokens = num_sink_tokens
|
||||
self.cos_sin_rerotation_cache = {}
|
||||
self._cos_cache = None
|
||||
self._sin_cache = None
|
||||
|
||||
@staticmethod
|
||||
def _rotate_half(x):
|
||||
x1 = x[..., : x.shape[-1] // 2]
|
||||
x2 = x[..., x.shape[-1] // 2 :]
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
def _apply_key_rotary_pos_emb(
|
||||
self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
rotated_key_states = (key_states * cos) + (self._rotate_half(key_states) * sin)
|
||||
return rotated_key_states
|
||||
|
||||
def _get_rerotation_cos_sin(
|
||||
self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
if key_states.shape[-2] not in self.cos_sin_rerotation_cache:
|
||||
# Upcast to float32 temporarily for better accuracy
|
||||
cos = cos.to(torch.float32)
|
||||
sin = sin.to(torch.float32)
|
||||
|
||||
# Compute the cos and sin required for back- and forward-rotating to one position earlier in the sequence
|
||||
original_cos = cos[self.num_sink_tokens + key_states.shape[-2] :]
|
||||
shifted_cos = cos[self.num_sink_tokens : -key_states.shape[-2]]
|
||||
original_sin = sin[self.num_sink_tokens + key_states.shape[-2] :]
|
||||
shifted_sin = sin[self.num_sink_tokens : -key_states.shape[-2]]
|
||||
rerotation_cos = original_cos * shifted_cos + original_sin * shifted_sin
|
||||
rerotation_sin = -original_sin * shifted_cos + original_cos * shifted_sin
|
||||
|
||||
self.cos_sin_rerotation_cache[key_states.shape[-2]] = (
|
||||
rerotation_cos.to(key_states.dtype).unsqueeze(0),
|
||||
rerotation_sin.to(key_states.dtype).unsqueeze(0),
|
||||
)
|
||||
return self.cos_sin_rerotation_cache[key_states.shape[-2]]
|
||||
|
||||
def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
|
||||
"""Returns the sequence length of the cached states. A layer index can be optionally passed."""
|
||||
if len(self.key_cache) <= layer_idx:
|
||||
return 0
|
||||
return self.key_cache[layer_idx].shape[-2]
|
||||
|
||||
def get_max_cache_shape(self) -> Optional[int]:
|
||||
"""Returns the maximum sequence length of the cache object, in case of SinkCache it is the window length."""
|
||||
return self.window_length
|
||||
|
||||
def update(
|
||||
self,
|
||||
key_states: torch.Tensor,
|
||||
value_states: torch.Tensor,
|
||||
layer_idx: int,
|
||||
cache_kwargs: Optional[Dict[str, Any]] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
|
||||
|
||||
Parameters:
|
||||
key_states (`torch.Tensor`):
|
||||
The new key states to cache.
|
||||
value_states (`torch.Tensor`):
|
||||
The new value states to cache.
|
||||
layer_idx (`int`):
|
||||
The index of the layer to cache the states for.
|
||||
cache_kwargs (`Dict[str, Any]`, `optional`):
|
||||
Additional arguments for the cache subclass. The following arguments can be used in `SinkCache`: `sin`,
|
||||
`cos` and `partial_rotation_size`. These arguments are used with models using RoPE, to recompute the
|
||||
rotation as the tokens are shifted.
|
||||
|
||||
Return:
|
||||
A tuple containing the updated key and value states.
|
||||
"""
|
||||
# Optional kwargs for `SinkCache` -- needed on models using RoPE. `partial_rotation_size` is used on models
|
||||
# with partially rotated position embeddings, like Phi or Persimmon.
|
||||
if cache_kwargs is None:
|
||||
cache_kwargs = {}
|
||||
sin = cache_kwargs.get("sin")
|
||||
cos = cache_kwargs.get("cos")
|
||||
partial_rotation_size = cache_kwargs.get("partial_rotation_size")
|
||||
using_rope = cos is not None and sin is not None
|
||||
|
||||
# Update the sin/cos cache, which holds sin/cos values for all possible positions
|
||||
if using_rope and layer_idx == 0:
|
||||
# BC: some models still pass `sin`/`cos` with 2 dims. In those models, they are the full sin/cos. Remove
|
||||
# after all RoPE models have a llama-like cache utilization.
|
||||
if cos.dim() == 2:
|
||||
self._cos_cache = cos
|
||||
self._sin_cache = sin
|
||||
else:
|
||||
if self._cos_cache is None:
|
||||
self._cos_cache = cos[0, ...]
|
||||
self._sin_cache = sin[0, ...]
|
||||
elif self._cos_cache.shape[0] < self.window_length:
|
||||
self._cos_cache = torch.cat([self._cos_cache, cos[0, ...]], dim=0)
|
||||
self._sin_cache = torch.cat([self._sin_cache, sin[0, ...]], dim=0)
|
||||
|
||||
# [bsz, num_heads, seq_len, head_dim]
|
||||
if len(self.key_cache) <= layer_idx:
|
||||
# Empty cache
|
||||
self.key_cache.append(key_states)
|
||||
self.value_cache.append(value_states)
|
||||
|
||||
elif key_states.shape[-2] + self.get_seq_length(layer_idx) < self.window_length:
|
||||
# Growing cache
|
||||
self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
|
||||
self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
|
||||
|
||||
else:
|
||||
# Shifting cache
|
||||
keys_to_keep = self.key_cache[layer_idx][
|
||||
:, :, -self.window_length + self.num_sink_tokens + key_states.shape[-2] :
|
||||
]
|
||||
|
||||
# On RoPE models, we need to recompute the Key rotation as the tokens are shifted
|
||||
if using_rope:
|
||||
rerotation_cos, rerotation_sin = self._get_rerotation_cos_sin(
|
||||
key_states, self._cos_cache[: self.window_length], self._sin_cache[: self.window_length]
|
||||
)
|
||||
if partial_rotation_size is not None:
|
||||
keys_to_keep, keys_pass = (
|
||||
keys_to_keep[..., :partial_rotation_size],
|
||||
keys_to_keep[..., partial_rotation_size:],
|
||||
)
|
||||
keys_to_keep = self._apply_key_rotary_pos_emb(keys_to_keep, rerotation_cos, rerotation_sin)
|
||||
if partial_rotation_size is not None:
|
||||
keys_to_keep = torch.cat((keys_to_keep, keys_pass), dim=-1)
|
||||
|
||||
# Concatenate sink tokens, shifted & rotated tokens (if needed), and new tokens
|
||||
sink_keys = self.key_cache[layer_idx][:, :, : self.num_sink_tokens]
|
||||
self.key_cache[layer_idx] = torch.cat([sink_keys, keys_to_keep, key_states], dim=-2)
|
||||
|
||||
sink_values = self.value_cache[layer_idx][:, :, : self.num_sink_tokens]
|
||||
values_to_keep = self.value_cache[layer_idx][
|
||||
:, :, -self.window_length + self.num_sink_tokens + value_states.shape[-2] :
|
||||
]
|
||||
self.value_cache[layer_idx] = torch.cat([sink_values, values_to_keep, value_states], dim=-2)
|
||||
|
||||
return self.key_cache[layer_idx], self.value_cache[layer_idx]
|
||||
|
||||
|
||||
def generate(model, window_length=256, num_sink_tokens=4, **kwargs):
|
||||
"""Custom generate function for SinkCache.
|
||||
|
||||
Args:
|
||||
model (`PreTrainedModel`):
|
||||
The model to generate from.
|
||||
window_length (`int`, *optional*, defaults to 256):
|
||||
The length of the context window.
|
||||
num_sink_tokens (`int`, *optional*, defaults to 4):
|
||||
The number of sink tokens. See the original paper for more information.
|
||||
"""
|
||||
# 1. General sanity checks
|
||||
# 1.a. A few arguments are not allowed, especially arguments that control caches.
|
||||
generation_config = kwargs.get("generation_config")
|
||||
default_global_generation_config = GenerationConfig()
|
||||
default_model_generation_config = model.generation_config
|
||||
for arg in UNSUPPORTED_GENERATION_ARGS:
|
||||
has_custom_gen_config_arg = (
|
||||
generation_config is not None
|
||||
# = and not (match global default or match model-specific default)
|
||||
and not (
|
||||
getattr(default_model_generation_config, arg) == getattr(generation_config, arg)
|
||||
or getattr(default_global_generation_config, arg) == getattr(generation_config, arg)
|
||||
)
|
||||
)
|
||||
kwargs_has_arg = arg in kwargs and kwargs[arg] is not None
|
||||
if kwargs_has_arg or has_custom_gen_config_arg:
|
||||
raise ValueError(
|
||||
f"`{arg}` is set, but it's not supported in this custom generate function. List of "
|
||||
f"unsupported arguments: {UNSUPPORTED_GENERATION_ARGS}"
|
||||
)
|
||||
|
||||
# 1.b. The model must be decoder-only
|
||||
if model.config.is_encoder_decoder:
|
||||
raise ValueError("This custom generate function only works with decoder-only models")
|
||||
|
||||
# 1.c. compatibility with transformers 4.52: we must pop `custom_generate` from kwargs, otherwise it will result
|
||||
# in an infinite loop when we call `model.generate`. This is solved in transformers 4.53.
|
||||
kwargs.pop("custom_generate", None)
|
||||
|
||||
# 2. Generate with SinkCache
|
||||
# 2.a. prepare the cache, if it was not passed.
|
||||
past_key_values = kwargs.pop("past_key_values", None)
|
||||
if past_key_values is None:
|
||||
past_key_values = SinkCache(window_length=window_length, num_sink_tokens=num_sink_tokens)
|
||||
elif not isinstance(past_key_values, SinkCache):
|
||||
raise ValueError(f"`past_key_values` must be a `SinkCache` instance, got a {type(past_key_values)} instance")
|
||||
|
||||
# 2.b. generate with the cache
|
||||
generation_outputs = model.generate(**kwargs, past_key_values=past_key_values, use_cache=True)
|
||||
return generation_outputs
|
||||
3
generation_config.json
Normal file
3
generation_config.json
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:2325da0f15bb848e018c5ae071b7943332e9f871d6b60e2ed22ca97d4cb993d2
|
||||
size 239
|
||||
151388
merges.txt
Normal file
151388
merges.txt
Normal file
File diff suppressed because it is too large
Load Diff
3
model.safetensors
Normal file
3
model.safetensors
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:f47f71177f32bcd101b7573ec9171e6a57f4f4d31148d38e382306f42996874b
|
||||
size 1503300328
|
||||
BIN
tokenizer.json
(Stored with Git LFS)
Normal file
BIN
tokenizer.json
(Stored with Git LFS)
Normal file
Binary file not shown.
3
tokenizer_config.json
Normal file
3
tokenizer_config.json
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:d5d09f07b48c3086c508b30d1c9114bd1189145b74e982a265350c923acd8101
|
||||
size 9732
|
||||
BIN
vocab.json
(Stored with Git LFS)
Normal file
BIN
vocab.json
(Stored with Git LFS)
Normal file
Binary file not shown.
Reference in New Issue
Block a user